octopi-1.1-py3-none-any.whl → octopi-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +1 -0
- octopi/datasets/cached_datset.py +1 -1
- octopi/datasets/generators.py +1 -1
- octopi/datasets/io.py +200 -0
- octopi/datasets/multi_config_generator.py +1 -1
- octopi/entry_points/common.py +5 -5
- octopi/entry_points/create_slurm_submission.py +1 -1
- octopi/entry_points/run_create_targets.py +6 -6
- octopi/entry_points/run_evaluate.py +4 -3
- octopi/entry_points/run_extract_mb_picks.py +5 -5
- octopi/entry_points/run_localize.py +8 -9
- octopi/entry_points/run_optuna.py +7 -7
- octopi/entry_points/run_segment_predict.py +4 -4
- octopi/entry_points/run_train.py +7 -8
- octopi/extract/localize.py +11 -19
- octopi/extract/membranebound_extract.py +11 -10
- octopi/extract/midpoint_extract.py +3 -3
- octopi/models/common.py +1 -1
- octopi/processing/create_targets_from_picks.py +3 -4
- octopi/processing/evaluate.py +24 -11
- octopi/processing/importers.py +4 -4
- octopi/pytorch/hyper_search.py +2 -3
- octopi/pytorch/model_search_submitter.py +4 -4
- octopi/pytorch/segmentation.py +141 -190
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +2 -2
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +128 -0
- octopi/{utils.py → utils/parsers.py} +10 -84
- octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
- octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
- octopi/workflows.py +236 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA +41 -29
- octopi-1.2.0.dist-info/RECORD +62 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
- octopi-1.2.0.dist-info/entry_points.txt +3 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
- octopi/io.py +0 -457
- octopi/processing/my_metrics.py +0 -26
- octopi/processing/writers.py +0 -102
- octopi-1.1.dist-info/RECORD +0 -59
- octopi-1.1.dist-info/entry_points.txt +0 -4
- /octopi/{losses.py → utils/losses.py} +0 -0
- /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
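Most of the changes in this release reorganize the package layout: the monolithic octopi/io.py and octopi/utils.py are removed in favor of an octopi/utils/ subpackage (config.py, io.py, parsers.py, plus the relocated losses, stopping-criteria, SLURM, and visualization helpers) and a new octopi/datasets/io.py, alongside the added octopi/workflows.py and multi-GPU segmentation module. For downstream code that imported the old modules directly, the import paths change roughly as sketched below; this is a non-authoritative summary assembled from the import edits in this diff, with the commented lines showing the 1.1 layout:

from octopi.datasets import io as dataset_io                # was: from octopi import io (dataset loaders)
from octopi.utils import io, parsers                        # was: from octopi import io, utils
from octopi.utils.submit_slurm import create_shellsubmit    # was: from octopi.submit_slurm import create_shellsubmit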
octopi/__init__.py
CHANGED
@@ -0,0 +1 @@
+__version__ = "1.2.0"
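With `__version__` now defined at the package root, the installed release can be checked programmatically, e.g.:

import octopi
print(octopi.__version__)   # "1.2.0"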
octopi/datasets/cached_datset.py
CHANGED
octopi/datasets/generators.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.datasets import dataset, augment, cached_datset
 from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
 from typing import List, Optional
-from octopi import io
+from octopi.datasets import io
 import torch, os, random, gc
 import multiprocess as mp

octopi/datasets/io.py
ADDED
@@ -0,0 +1,200 @@
+"""
+Data loading, processing, and dataset operations for the datasets module.
+"""
+
+from monai.data import DataLoader, CacheDataset, Dataset
+from monai.transforms import (
+    Compose,
+    NormalizeIntensityd,
+    EnsureChannelFirstd,
+)
+from sklearn.model_selection import train_test_split
+from collections import defaultdict
+from copick_utils.io import readers
+import copick, torch, os, random
+from typing import List
+from tqdm import tqdm
+
+
+def load_training_data(root,
+                       runIDs: List[str],
+                       voxel_spacing: float,
+                       tomo_algorithm: str,
+                       segmenation_name: str,
+                       segmentation_session_id: str = None,
+                       segmentation_user_id: str = None,
+                       progress_update: bool = True):
+    """
+    Load training data from CoPick runs.
+    """
+    data_dicts = []
+    # Use tqdm for progress tracking only if progress_update is True
+    iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
+    for runID in iterable:
+        run = root.get_run(str(runID))
+        tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+        segmentation = readers.segmentation(run,
+                                            voxel_spacing,
+                                            segmenation_name,
+                                            segmentation_session_id,
+                                            segmentation_user_id)
+        data_dicts.append({"image": tomogram, "label": segmentation})
+
+    return data_dicts
+
+
+def load_predict_data(root,
+                      runIDs: List[str],
+                      voxel_spacing: float,
+                      tomo_algorithm: str):
+    """
+    Load prediction data from CoPick runs.
+    """
+    data_dicts = []
+    for runID in tqdm(runIDs):
+        run = root.get_run(str(runID))
+        tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+        data_dicts.append({"image": tomogram})
+
+    return data_dicts
+
+
+def create_predict_dataloader(
+    root,
+    voxel_spacing: float,
+    tomo_algorithm: str,
+    runIDs: str = None,
+    ):
+    """
+    Create a dataloader for prediction data.
+    """
+    # define pre transforms
+    pre_transforms = Compose(
+        [ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+          NormalizeIntensityd(keys=["image"]),
+        ])
+
+    # Split trainRunIDs, validateRunIDs, testRunIDs
+    if runIDs is None:
+        runIDs = [run.name for run in root.runs]
+    test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)
+
+    bs = min( len(test_files), 4)
+    test_ds = CacheDataset(data=test_files, transform=pre_transforms)
+    test_loader = DataLoader(test_ds,
+                             batch_size=bs,
+                             shuffle=False,
+                             num_workers=4,
+                             pin_memory=torch.cuda.is_available())
+    return test_loader, test_ds
+
+
+def adjust_to_multiple(value, multiple = 16):
+    """
+    Adjust a value to be a multiple of the specified number.
+    """
+    return int((value // multiple) * multiple)
+
+
+def get_input_dimensions(dataset, crop_size: int):
+    """
+    Get input dimensions for the dataset.
+    """
+    nx = dataset[0]['image'].shape[1]
+    if crop_size > nx:
+        first_dim = adjust_to_multiple(nx/2)
+        return first_dim, crop_size, crop_size
+    else:
+        return crop_size, crop_size, crop_size
+
+
+def get_num_classes(copick_config_path: str):
+    """
+    Get the number of classes from a CoPick configuration.
+    """
+    root = copick.from_file(copick_config_path)
+    return len(root.pickable_objects) + 1
+
+
+def split_multiclass_dataset(runIDs,
+                             train_ratio: float = 0.7,
+                             val_ratio: float = 0.15,
+                             test_ratio: float = 0.15,
+                             return_test_dataset: bool = True,
+                             random_state: int = 42):
+    """
+    Splits a given dataset into three subsets: training, validation, and testing. If the dataset
+    has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
+    list, it is split without categorization.
+
+    Parameters:
+    - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
+    - train_ratio: Proportion of the dataset for training.
+    - val_ratio: Proportion of the dataset for validation.
+    - test_ratio: Proportion of the dataset for testing.
+    - return_test_dataset: Whether to return the test dataset.
+    - random_state: Random state for reproducibility.
+
+    Returns:
+    - trainRunIDs: Training subset.
+    - valRunIDs: Validation subset.
+    - testRunIDs: Testing subset (if return_test_dataset is True, otherwise None).
+    """
+
+    # Ensure the ratios add up to 1
+    assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
+
+    # Check if the dataset has categories
+    if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
+        # Group by category
+        grouped = defaultdict(list)
+        for item in runIDs:
+            grouped[item[0]].append(item)
+
+        # Split each category
+        trainRunIDs, valRunIDs, testRunIDs = [], [], []
+        for category, items in grouped.items():
+            # Shuffle for randomness
+            random.shuffle(items)
+            # Split into train and remaining
+            train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
+            trainRunIDs.extend(train_items)
+
+            if return_test_dataset:
+                # Split remaining into validation and test
+                val_items, test_items = train_test_split(
+                    remaining,
+                    test_size=(test_ratio / (val_ratio + test_ratio)),
+                    random_state=random_state,
+                )
+                valRunIDs.extend(val_items)
+                testRunIDs.extend(test_items)
+            else:
+                valRunIDs.extend(remaining)
+                testRunIDs = []
+    else:
+        # If no categories, split as a 1D list
+        trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
+        if return_test_dataset:
+            valRunIDs, testRunIDs = train_test_split(
+                remaining,
+                test_size=(test_ratio / (val_ratio + test_ratio)),
+                random_state=random_state,
+            )
+        else:
+            valRunIDs = remaining
+            testRunIDs = []
+
+    return trainRunIDs, valRunIDs, testRunIDs
+
+
+def load_copick_config(path: str):
+    """
+    Load a CoPick configuration from file.
+    """
+    if os.path.isfile(path):
+        root = copick.from_file(path)
+    else:
+        raise FileNotFoundError(f"Copick Config Path does not exist: {path}")
+
+    return root
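A minimal usage sketch of the new octopi/datasets/io.py helpers shown above, assuming a CoPick project on disk; the config path and parameter values are placeholders:

from octopi.datasets import io

# Load the CoPick project and derive the class count (pickable objects + background).
root = io.load_copick_config("copick_config.json")       # placeholder path
n_classes = io.get_num_classes("copick_config.json")

# Split run IDs 70/15/15 into train / validation / test subsets.
run_ids = [run.name for run in root.runs]
train_ids, val_ids, test_ids = io.split_multiclass_dataset(run_ids)

# Build a prediction dataloader over the held-out runs.
loader, dataset = io.create_predict_dataloader(
    root, voxel_spacing=10, tomo_algorithm="wbp", runIDs=test_ids)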
octopi/datasets/multi_config_generator.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.datasets import dataset, augment, cached_datset
 from octopi.datasets.generators import TrainLoaderManager
 from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
-from octopi import io
+from octopi.datasets import io
 import multiprocess as mp
 from typing import List
 from tqdm import tqdm
octopi/entry_points/common.py
CHANGED
@@ -1,4 +1,4 @@
-from octopi import
+from octopi.utils import parsers
 import argparse

 def add_model_parameters(parser, octopi = False):
@@ -8,8 +8,8 @@ def add_model_parameters(parser, octopi = False):

     # Add U-Net model parameters
     parser.add_argument("--Nclass", type=int, required=False, default=3, help="Number of prediction classes in the model")
-    parser.add_argument("--channels", type=
-    parser.add_argument("--strides", type=
+    parser.add_argument("--channels", type=parsers.parse_int_list, required=False, default='32,64,96,96', help="List of channel sizes")
+    parser.add_argument("--strides", type=parsers.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
     parser.add_argument("--res-units", type=int, required=False, default=1, help="Number of residual units in the UNet")
     parser.add_argument("--dim-in", type=int, required=False, default=96, help="Input dimension for the UNet model")

@@ -52,11 +52,11 @@ def add_inference_parameters(parser):

     parser.add_argument("--tomo-alg", required=False, default = 'wbp',
                         help="Tomogram algorithm used for produces segmentation prediction masks.")
-    parser.add_argument("--seg-info", type=
+    parser.add_argument("--seg-info", type=parsers.parse_target, required=False,
                         default='predict,octopi,1', help='Information Query to save Segmentation predictions under, e.g., (e.g., "name" or "name,user_id,session_id" - Default UserID is octopi and SessionID is 1')
     parser.add_argument("--tomo-batch-size", type=int, default=25, required=False,
                         help="Batch size for tomogram processing.")
-    parser.add_argument("--run-ids", type=
+    parser.add_argument("--run-ids", type=parsers.parse_list, default=None, required=False,
                         help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")

 def add_localize_parameters(parser):
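The entry points below now delegate argument parsing to octopi.utils.parsers. The parser module itself is not part of this diff, so the following is only an illustrative approximation of the behavior implied by the help strings (comma-separated "name,user_id,session_id" targets and plain or bracketed run-ID lists), not the library's actual code:

# Hypothetical stand-ins for parsers.parse_target and parsers.parse_list.
def parse_target(value: str):
    # "name" or "name,user_id,session_id" -> (name, user_id, session_id)
    parts = [p.strip() for p in value.split(",")]
    name = parts[0]
    user_id = parts[1] if len(parts) > 1 else "octopi"   # default UserID per the help text
    session_id = parts[2] if len(parts) > 2 else "1"     # default SessionID per the help text
    return name, user_id, session_id

def parse_list(value: str):
    # "run1,run2,run3" or "[run1,run2,run3]" -> ["run1", "run2", "run3"]
    return [v.strip() for v in value.strip("[]").split(",") if v.strip()]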
octopi/entry_points/create_slurm_submission.py
CHANGED
@@ -1,5 +1,5 @@
 from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
-from octopi.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
+from octopi.utils.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
 from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
 from octopi.entry_points import common
 from octopi import utils
octopi/entry_points/run_create_targets.py
CHANGED
@@ -1,8 +1,8 @@
 import octopi.processing.create_targets_from_picks as create_targets
 from typing import List, Tuple, Union
+from octopi.utils import io, parsers
 from collections import defaultdict
 import argparse, copick, yaml, os
-from octopi import utils, io
 from tqdm import tqdm
 import numpy as np

@@ -160,16 +160,16 @@ def parse_args():

     input_group = parser.add_argument_group("Input Arguments")
     input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
-    input_group.add_argument("--target", type=
+    input_group.add_argument("--target", type=parsers.parse_target, action="append", default=None, help='Target specifications: "name" or "name,user_id,session_id".')
     input_group.add_argument("--picks-session-id", type=str, default=None, help="Session ID for the picks.")
     input_group.add_argument("--picks-user-id", type=str, default=None, help="User ID associated with the picks.")
-    input_group.add_argument("--seg-target", type=
-    input_group.add_argument("--run-ids", type=
+    input_group.add_argument("--seg-target", type=parsers.parse_target, action="append", default=[], help='Segmentation targets: "name" or "name,user_id,session_id".')
+    input_group.add_argument("--run-ids", type=parsers.parse_list, default=None, help="List of run IDs.")

     # Parameters
     parameters_group = parser.add_argument_group("Parameters")
     parameters_group.add_argument("--tomo-alg", type=str, default="wbp", help="Tomogram reconstruction algorithm.")
-    parameters_group.add_argument("--radius-scale", type=float, default=0.
+    parameters_group.add_argument("--radius-scale", type=float, default=0.7, help="Scale factor for object radius.")
     parameters_group.add_argument("--voxel-size", type=float, default=10, help="Voxel size for tomogram reconstruction.")

     output_group = parser.add_argument_group("Output Arguments")
@@ -275,7 +275,7 @@ def save_parameters(args, output_path: str):
         existing_data[input_key] = new_entry[input_key]

     # Save back to the YAML file
-
+    io.save_parameters_yaml(existing_data, output_path)

 if __name__ == "__main__":
     cli()
octopi/entry_points/run_evaluate.py
CHANGED
@@ -1,5 +1,5 @@
 import octopi.processing.evaluate as evaluate
-
+from octopi.utils import parsers
 from typing import List
 import argparse

@@ -31,6 +31,7 @@ def cli():
     """
     CLI entry point for running evaluation.
     """
+
     parser = argparse.ArgumentParser(
         description='Run evaluation on pick and place predictions.',
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -43,8 +44,8 @@ def cli():
     parser.add_argument('--predict-session-id', type=str, required=False, default= None, help='Session ID for prediction data')
     parser.add_argument('--save-path', type=str, required=False, default= None, help='Path to save evaluation results')
     parser.add_argument('--distance-threshold-scale', type=float, required=False, default = 0.8, help='Compute Distance Threshold Based on Particle Radius')
-    parser.add_argument('--object-names', type=
-    parser.add_argument('--run-ids', type=
+    parser.add_argument('--object-names', type=parsers.parse_list, default=None, required=False, help='Optional list of object names to evaluate, e.g., ribosome,apoferritin or [ribosome,apoferritin].')
+    parser.add_argument('--run-ids', type=parsers.parse_list, default=None, required=False, help='Optional list of run IDs to evaluate, e.g., run1,run2,run3 or [run1,run2,run3].')

     args = parser.parse_args()

octopi/entry_points/run_extract_mb_picks.py
CHANGED
@@ -1,5 +1,5 @@
 from octopi.extract import membranebound_extract as extract
-from octopi import
+from octopi.utils import parsers
 import argparse, json, pprint, copick, json
 from typing import List, Tuple, Optional
 import multiprocess as mp
@@ -58,12 +58,12 @@ def cli():
     parser.add_argument('--config', type=str, required=True, help='Path to the configuration file.')
     parser.add_argument('--voxel-size', type=float, required=False, default=10, help='Voxel size.')
     parser.add_argument('--distance-threshold', type=float, required=False, default=10, help='Distance threshold.')
-    parser.add_argument('--picks-info', type=
-    parser.add_argument('--membrane-info', type=
-    parser.add_argument('--organelle-info', type=
+    parser.add_argument('--picks-info', type=parsers.parse_target, required=True, help='Query for the picks (e.g., "name" or "name,user_id,session_id".).')
+    parser.add_argument('--membrane-info', type=parsers.parse_target, required=False, help='Query for the membrane segmentation (e.g., "name" or "name,user_id,session_id".).')
+    parser.add_argument('--organelle-info', type=parsers.parse_target, required=False, help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
     parser.add_argument('--save-user-id', type=str, required=False, default=None, help='User ID to save the new picks.')
     parser.add_argument('--save-session-id', type=str, required=True, help='Session ID to save the new picks.')
-    parser.add_argument('--runIDs', type=
+    parser.add_argument('--runIDs', type=parsers.parse_list, required=False, help='List of run IDs to process.')
     parser.add_argument('--n-procs', type=int, required=False, default=None, help='Number of processes to use.')

     args = parser.parse_args()
octopi/entry_points/run_localize.py
CHANGED
@@ -1,11 +1,10 @@
 from octopi.entry_points import common
+from octopi.utils import parsers, io
 from octopi.extract import localize
-from octopi import utils
 import copick, argparse, pprint
 from typing import List, Tuple
 import multiprocess as mp
 from tqdm import tqdm
-import os

 def pick_particles(
     copick_config_path: str,
@@ -53,13 +52,13 @@ def pick_particles(

     # Nprocesses shouldnt exceed computation resource or number of available runs
     n_run_ids = len(run_ids)
-    n_procs = min(mp.
+    n_procs = min(mp.cpu_count(), n_procs, n_run_ids)

     # Run Localization - Main Parallelization Loop
     print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
     with mp.Pool(processes=n_procs) as pool:
         with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
-            worker_func = lambda run_id: localize.
+            worker_func = lambda run_id: localize.process_localization(
                 root.get_run(run_id),
                 objects,
                 seg_info,
@@ -85,20 +84,20 @@ def localize_parser(parser_description, add_slurm: bool = False):
     input_group = parser.add_argument_group("Input Arguments")
     input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
     input_group.add_argument("--method", type=str, choices=['watershed', 'com'], default='watershed', required=False, help="Localization method to use.")
-    input_group.add_argument('--seg-info', type=
+    input_group.add_argument('--seg-info', type=parsers.parse_target, required=False, default='predict,octopi,1', help='Query for the organelles segmentations (e.g., "name" or "name,user_id,session_id".).')
     input_group.add_argument("--voxel-size", type=float, default=10, required=False, help="Voxel size for localization.")
-    input_group.add_argument("--runIDs", type=
+    input_group.add_argument("--runIDs", type=parsers.parse_list, default = None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")

     localize_group = parser.add_argument_group("Localize Arguments")
     localize_group.add_argument("--radius-min-scale", type=float, default=0.5, required=False, help="Minimum radius scale for particles.")
     localize_group.add_argument("--radius-max-scale", type=float, default=1.0, required=False, help="Maximum radius scale for particles.")
     localize_group.add_argument("--filter-size", type=int, default=10, required=False, help="Filter size for localization.")
-    localize_group.add_argument("--pick-objects", type=
+    localize_group.add_argument("--pick-objects", type=parsers.parse_list, default=None, required=False, help="Specific Objects to Find Picks for.")
     localize_group.add_argument("--n-procs", type=int, default=8, required=False, help="Number of CPU processes to parallelize runs across. Defaults to the max number of cores available or available runs.")

     output_group = parser.add_argument_group("Output Arguments")
     output_group.add_argument("--pick-session-id", type=str, default='1', required=False, help="Session ID for the particle picks.")
-    output_group.add_argument("--pick-user-id", type=str, default='
+    output_group.add_argument("--pick-user-id", type=str, default='octopi', required=False, help="User ID for the particle picks.")

     if add_slurm:
         slurm_group = parser.add_argument_group("SLURM Arguments")
@@ -165,7 +164,7 @@ def save_parameters(args: argparse.Namespace,
     pprint.pprint(params); print()

     # Save to YAML file
-
+    io.save_parameters_yaml(params, output_path)

 if __name__ == "__main__":
     cli()
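The localization entry point now caps its worker count at the minimum of the machine's CPU count, the requested --n-procs, and the number of runs before fanning localize.process_localization out over a multiprocess pool. A stripped-down sketch of that pattern follows; the exact pool dispatch call is not visible in this hunk, so the use of imap_unordered here is an assumption:

import multiprocess as mp
from tqdm import tqdm

def run_parallel(run_ids, worker_func, n_procs=8):
    # Never spawn more workers than CPUs or available runs.
    n_procs = min(mp.cpu_count(), n_procs, len(run_ids))
    print(f"Using {n_procs} processes to parallelize across {len(run_ids)} run IDs.")
    with mp.Pool(processes=n_procs) as pool:
        with tqdm(total=len(run_ids), desc="Localization", unit="run") as pbar:
            for _ in pool.imap_unordered(worker_func, run_ids):
                pbar.update(1)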
octopi/entry_points/run_optuna.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.pytorch.model_search_submitter import ModelSearchSubmit
 from octopi.entry_points import common
+from octopi.utils import parsers, io
 import argparse, os, pprint
-from octopi import utils

 def optuna_parser(parser_description, add_slurm: bool = False):
     """
@@ -20,22 +20,22 @@ def optuna_parser(parser_description, add_slurm: bool = False):
     # Input Arguments
     input_group = parser.add_argument_group("Input Arguments")
     common.add_config(input_group, single_config=False)
-    input_group.add_argument("--target-info", type=
+    input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
                              help="Target information, e.g., 'name' or 'name,user_id,session_id'")
     input_group.add_argument("--tomo-alg", default='wbp',
                              help="Tomogram algorithm used for training")
     input_group.add_argument("--mlflow-experiment-name", type=str, default="model-search", required=False,
                              help="Name of the MLflow experiment (default: 'model-search').")
-    input_group.add_argument("--trainRunIDs", type=
+    input_group.add_argument("--trainRunIDs", type=parsers.parse_list, default=None, required=False,
                              help="List of training run IDs, e.g., run1,run2 or [run1,run2].")
-    input_group.add_argument("--validateRunIDs", type=
+    input_group.add_argument("--validateRunIDs", type=parsers.parse_list, default=None, required=False,
                              help="List of validation run IDs, e.g., run3,run4 or [run3,run4].")
     input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
                              "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")

     model_group = parser.add_argument_group("Model Arguments")
     model_group.add_argument("--model-type", type=str, default='Unet', required=False,
-                             choices=['Unet', 'AttentionUnet'],
+                             choices=['Unet', 'AttentionUnet', 'MedNeXt', 'SegResNet'],
                              help="Model type to use for training")
     model_group.add_argument("--Nclass", type=int, default=3, required=False, help="Number of prediction classes in the model")

@@ -61,7 +61,7 @@ def cli():
     args = optuna_parser(description)

     # Parse the CoPick configuration paths
-    if len(args.config) > 1: copick_configs =
+    if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
     else: copick_configs = args.config[0]

     # Create the model exploration directory
@@ -133,7 +133,7 @@ def save_parameters(args: argparse.Namespace,
     pprint.pprint(params); print()

     # Save to YAML file
-
+    io.save_parameters_yaml(params, output_path)

 if __name__ == "__main__":
     cli()
octopi/entry_points/run_segment_predict.py
CHANGED
@@ -1,8 +1,8 @@
+import torch, argparse, json, pprint, yaml, os
 from octopi.pytorch import segmentation
 from octopi.entry_points import common
-import torch, argparse, json, pprint, yaml, os
-from octopi import utils
 from typing import List, Tuple
+from octopi.utils import io

 def inference(
     copick_config_path: str,
@@ -136,7 +136,7 @@ def save_parameters(args: argparse.Namespace,
                     output_path: str):

     # Load the model config
-    model_config =
+    model_config = io.load_yaml(args.model_config)

     # Create parameters dictionary
     params = {
@@ -160,7 +160,7 @@ def save_parameters(args: argparse.Namespace,
     pprint.pprint(params); print()

     # Save to YAML file
-
+    io.save_parameters_yaml(params, output_path)

 if __name__ == "__main__":
     cli()
octopi/entry_points/run_train.py
CHANGED
@@ -2,12 +2,11 @@ from octopi.datasets import generators, multi_config_generator
 from monai.losses import DiceLoss, FocalLoss, TverskyLoss
 from octopi.models import common as builder
 from monai.metrics import ConfusionMatrixMetric
+from octopi.utils import parsers, io
 from octopi.entry_points import common
 from octopi.pytorch import trainer
-from octopi import io, utils
 import torch, os, argparse
 from typing import List, Optional, Tuple
-import pprint

 def train_model(
     copick_config_path: str,
@@ -56,7 +55,7 @@ def train_model(


     # Get the data splits
-    ratios =
+    ratios = parsers.parse_data_split(data_split)
     data_generator.get_data_splits(
         trainRunIDs = trainRunIDs,
         validateRunIDs = validateRunIDs,
@@ -114,11 +113,11 @@ def train_model_parser(parser_description, add_slurm: bool = False):
     # Input Arguments
     input_group = parser.add_argument_group("Input Arguments")
     common.add_config(input_group, single_config=False)
-    input_group.add_argument("--target-info", type=
+    input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
                              help="Target information, e.g., 'name' or 'name,user_id,session_id'. Default is 'targets,octopi,1'.")
     input_group.add_argument("--tomo-alg", default='wbp', help="Tomogram algorithm used for training")
-    input_group.add_argument("--trainRunIDs", type=
-    input_group.add_argument("--validateRunIDs", type=
+    input_group.add_argument("--trainRunIDs", type=parsers.parse_list, help="List of training run IDs, e.g., run1,run2,run3")
+    input_group.add_argument("--validateRunIDs", type=parsers.parse_list, help="List of validation run IDs, e.g., run4,run5,run6")
     input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
                              "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")

@@ -153,11 +152,11 @@ def cli():
     args = train_model_parser(parser_description)

     # Parse the CoPick configuration paths
-    if len(args.config) > 1: copick_configs =
+    if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
     else: copick_configs = args.config[0]

     if args.model_config:
-        model_config =
+        model_config = io.load_yaml(args.model_config)
     else:
         model_config = get_model_config(args.channels, args.strides, args.res_units, args.Nclass, args.dim_in)

octopi/extract/localize.py
CHANGED
@@ -3,15 +3,15 @@ from scipy.cluster.hierarchy import fcluster, linkage
 from skimage.segmentation import watershed
 from typing import List, Optional, Tuple
 from skimage.measure import regionprops
+from copick_utils.io import readers
 from scipy.spatial import distance
 from dataclasses import dataclass
-from octopi import io
 import scipy.ndimage as ndi
 from tqdm import tqdm
 import numpy as np
 import gc

-def
+def process_localization(run,
                          objects,
                          seg_info: Tuple[str, str, str],
                          method: str = 'com',
@@ -27,12 +27,12 @@ def processs_localization(run,
         raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")

     # Get Segmentation
-    seg =
-
-
-
-
-
+    seg = readers.segmentation(
+        run, voxel_size,
+        seg_info[0],
+        user_id=seg_info[1],
+        session_id=seg_info[2],
+        raise_error=False)

     # Preprocess Segmentation
     # seg = preprocess_segmentation(seg, voxel_size, objects)
@@ -99,8 +99,8 @@ def extract_particle_centroids_via_watershed(
         max_particle_size (int): Maximum size threshold for particles.
     """

-    if maxima_filter_size is None or maxima_filter_size
-
+    if maxima_filter_size is None or maxima_filter_size <= 0:
+        raise ValueError('Enter a Non-Zero Filter Size!')

     # Calculate minimum and maximum particle volumes based on the given radii
     min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
@@ -117,12 +117,8 @@ def extract_particle_centroids_via_watershed(
     # Structuring element for erosion and dilation
     struct_elem = ball(1)
     eroded = binary_erosion(binary_mask, struct_elem)
-    del binary_mask
-    gc.collect()

     dilated = binary_dilation(eroded, struct_elem)
-    del eroded
-    gc.collect()

     # Distance transform and local maxima detection
     distance = ndi.distance_transform_edt(dilated)
@@ -131,12 +127,11 @@ def extract_particle_centroids_via_watershed(
     # Watershed segmentation
     markers, _ = ndi.label(local_max)
     del local_max
-    markers = markers.astype(np.uint8)
     gc.collect()

     watershed_labels = watershed(-distance, markers, mask=dilated)
+    distance, markers, dilated = None, None, None
     del distance, markers, dilated
-    watershed_labels = watershed_labels.astype(np.uint8)
     gc.collect()

     # Extract region properties and filter based on particle size
@@ -147,9 +142,6 @@ def extract_particle_centroids_via_watershed(
         # Option 1: Use all centroids
         all_centroids.append(region.centroid)

-    del watershed_labels
-    gc.collect()
-
     return all_centroids

 def extract_particle_centroids_via_com(
|