octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
--- /dev/null
+++ octopi/entry_points/run_localize.py
@@ -0,0 +1,222 @@
+from octopi.entry_points import common
+from octopi.extract import localize
+from octopi import utils
+import copick, argparse, pprint
+from typing import List, Tuple
+import multiprocess as mp
+from tqdm import tqdm
+
+def pick_particles(
+    copick_config_path: str,
+    method: str,
+    seg_info: Tuple[str, str, str],
+    voxel_size: float,
+    pick_session_id: str,
+    pick_user_id: str,
+    radius_min_scale: float,
+    radius_max_scale: float,
+    filter_size: float,
+    pick_objects: List[str],
+    runIDs: List[str],
+    n_procs: int,
+    ):
+
+    # Load the Copick Project
+    root = copick.from_file(copick_config_path)
+
+    # Get objects that can be Picked
+    objects = [(obj.name, obj.label, obj.radius) for obj in root.pickable_objects if obj.is_particle]
+
+    # Verify each object has the required attributes
+    for obj in objects:
+        if len(obj) < 3 or not isinstance(obj[2], (float, int)):
+            raise ValueError(f"Invalid object format: {obj}. Expected a tuple with (name, label, radius).")
+
+    # Filter elements
+    if pick_objects is not None:
+        objects = [obj for obj in objects if obj[0] in pick_objects]
+
+    print('Running Localization on the Following Objects: ')
+    print(', '.join([f'{obj[0]} (Label: {obj[1]})' for obj in objects]) + '\n')
+
+    # Either Specify Input RunIDs or Run on All RunIDs
+    if runIDs: print('Running Localization on the Following RunIDs: ' + ', '.join(runIDs) + '\n')
+    run_ids = runIDs if runIDs else [run.name for run in root.runs]
+    n_run_ids = len(run_ids)
+
+    # Determine the number of processes to use
+    if n_procs is None:
+        n_procs = min(int(mp.cpu_count()//4), n_run_ids)
+    print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
+
+    # Initialize tqdm progress bar
+    with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
+        for _iz in range(0, n_run_ids, n_procs):
+
+            start_idx = _iz
+            end_idx = min(_iz + n_procs, n_run_ids)  # Ensure end_idx does not exceed n_run_ids
+            print(f"\nProcessing runIDs from {start_idx} -> {end_idx} (out of {n_run_ids})")
+
+            processes = []
+            for _in in range(n_procs):
+                _iz_this = _iz + _in
+                if _iz_this >= n_run_ids:
+                    break
+                run_id = run_ids[_iz_this]
+                run = root.get_run(run_id)
+                p = mp.Process(
+                    target=localize.processs_localization,
+                    args=(run,
+                          objects,
+                          seg_info,
+                          method,
+                          voxel_size,
+                          filter_size,
+                          radius_min_scale,
+                          radius_max_scale,
+                          pick_session_id,
+                          pick_user_id),
+                )
+                processes.append(p)
+
+            for p in processes:
+                p.start()
+
+            for p in processes:
+                p.join()
+
+            for p in processes:
+                p.close()
+
+            # Update tqdm progress bar
+            pbar.update(len(processes))
+
+    print('Localization Complete!')
+
+def localize_parser(parser_description, add_slurm: bool = False):
+    parser = argparse.ArgumentParser(
+        description=parser_description,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    input_group = parser.add_argument_group("Input Arguments")
+    input_group.add_argument("--config", type=str, required=True, help="Path to the CoPick configuration file.")
+    input_group.add_argument("--method", type=str, choices=['watershed', 'com'], default='watershed', required=False, help="Localization method to use.")
+    input_group.add_argument('--seg-info', type=utils.parse_target, required=True, help='Query for the organelle segmentations (e.g., "name" or "name,user_id,session_id").')
+    input_group.add_argument("--voxel-size", type=float, default=10, required=False, help="Voxel size for localization.")
+    input_group.add_argument("--runIDs", type=utils.parse_list, default=None, required=False, help="List of runIDs to run inference on, e.g., run1,run2,run3 or [run1,run2,run3].")
+
+    localize_group = parser.add_argument_group("Localize Arguments")
+    localize_group.add_argument("--radius-min-scale", type=float, default=0.5, required=False, help="Minimum radius scale for particles.")
+    localize_group.add_argument("--radius-max-scale", type=float, default=1.0, required=False, help="Maximum radius scale for particles.")
+    localize_group.add_argument("--filter-size", type=int, default=10, required=False, help="Filter size for localization.")
+    localize_group.add_argument("--pick-objects", type=utils.parse_list, default=None, required=False, help="Specific objects to find picks for.")
+    localize_group.add_argument("--n-procs", type=int, default=None, required=False, help="Number of CPU processes to parallelize runs across. Defaults to the max number of cores available or available runs.")
+
+    output_group = parser.add_argument_group("Output Arguments")
+    output_group.add_argument("--pick-session-id", type=str, default='1', required=False, help="Session ID for the particle picks.")
+    output_group.add_argument("--pick-user-id", type=str, default='monai', required=False, help="User ID for the particle picks.")
+
+    if add_slurm:
+        slurm_group = parser.add_argument_group("SLURM Arguments")
+        common.add_slurm_parameters(slurm_group, 'localize', gpus=0)
+
+    args = parser.parse_args()
+    return args
+
+# Entry point with argparse
+def cli():
+
+    parser_description = "Localize particles in tomograms using multiprocessing."
+    args = localize_parser(parser_description)
+
+    # Save YAML with Parameters
+    output_yaml = f'localize_{args.pick_user_id}_{args.pick_session_id}.yaml'
+    save_parameters(args, output_yaml)
+
+    # Set multiprocessing start method
+    mp.set_start_method("spawn")
+
+    pick_particles(
+        copick_config_path=args.config,
+        method=args.method,
+        seg_info=args.seg_info,
+        voxel_size=args.voxel_size,
+        pick_session_id=args.pick_session_id,
+        pick_user_id=args.pick_user_id,
+        radius_min_scale=args.radius_min_scale,
+        radius_max_scale=args.radius_max_scale,
+        filter_size=args.filter_size,
+        runIDs=args.runIDs,
+        pick_objects=args.pick_objects,
+        n_procs=args.n_procs,
+    )
+
+def save_parameters(args: argparse.Namespace,
+                    output_path: str):
+
+    # Organize parameters into categories
+    params = {
+        "input": {
+            "config": args.config,
+            "seg_name": args.seg_info[0],
+            "seg_user_id": args.seg_info[1],
+            "seg_session_id": args.seg_info[2],
+            "voxel_size": args.voxel_size
+        },
+        "output": {
+            "pick_session_id": args.pick_session_id,
+            "pick_user_id": args.pick_user_id
+        },
+        "parameters": {
+            "method": args.method,
+            "radius_min_scale": args.radius_min_scale,
+            "radius_max_scale": args.radius_max_scale,
+            "filter_size": args.filter_size,
+            "runIDs": args.runIDs
+        }
+    }
+
+    # Print the parameters
+    print("\nParameters for Localization:")
+    pprint.pprint(params); print()
+
+    # Save to YAML file
+    utils.save_parameters_yaml(params, output_path)
+
+if __name__ == "__main__":
+    cli()
+
+# def time_pick_particles():
+#     import json, time
+
+#     # Set multiprocessing start method
+#     mp.set_start_method("spawn")
+
+#     copick_config_path = "/mnt/simulations/ml_challenge/ml_config.json"  # Replace with your actual path
+#     n_procs_list = [1, 4, 8, 16, 32]  # Adjust based on your needs
+#     n_procs_list = [32, 16, 8, 4, 1]
+#     timing_results = {}
+
+#     session_id = 1
+#     for n_procs in n_procs_list:
+#         print(f"Testing with {n_procs} processes...")
+#         start_time = time.time()
+#         pick_particles(
+#             copick_config_path=copick_config_path,
+#             pick_session_id=str(session_id),
+#             n_procs=n_procs
+#         )
+#         elapsed_time = time.time() - start_time
+#         timing_results[n_procs] = elapsed_time
+#         print(f"Elapsed time with {n_procs} processes: {elapsed_time:.2f} seconds")
+
+#         session_id += 1
+
+#     # Save timing results to a JSON file
+#     with open("timing_results.json", "w") as f:
+#         json.dump(timing_results, f, indent=4)
+
+#     print("Timing results saved to 'timing_results.json'")
+
+# if __name__ == "__main__":
+#     time_pick_particles()
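For reference, a minimal sketch of driving this entry point from Python instead of the shell; the config path and segmentation query below are hypothetical placeholders, and every keyword mirrors the pick_particles signature above (values are the parser defaults where shown):

from octopi.entry_points.run_localize import pick_particles
import multiprocess as mp

if __name__ == "__main__":
    mp.set_start_method("spawn")  # matches what cli() sets before spawning workers
    pick_particles(
        copick_config_path="config.json",      # hypothetical CoPick project file
        method="watershed",                    # or "com"
        seg_info=("membrane", "octopi", "1"),  # hypothetical (name, user_id, session_id) query
        voxel_size=10,
        pick_session_id="1",                   # parser default
        pick_user_id="monai",                  # parser default
        radius_min_scale=0.5,
        radius_max_scale=1.0,
        filter_size=10,
        pick_objects=None,                     # None = all pickable particle objects
        runIDs=None,                           # None = every run in the project
        n_procs=4,
    )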
--- /dev/null
+++ octopi/entry_points/run_optuna.py
@@ -0,0 +1,139 @@
+from octopi.pytorch.model_search_submitter import ModelSearchSubmit
+from octopi.entry_points import common
+import argparse, os, pprint
+from octopi import utils
+
+def optuna_parser(parser_description, add_slurm: bool = False):
+    """
+    Create an argument parser for model architecture search using Optuna.
+
+    Args:
+        parser_description (str): Description of the parser
+        add_slurm (bool): Whether to add SLURM-specific arguments
+    """
+
+    parser = argparse.ArgumentParser(
+        description=parser_description,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    # Input Arguments
+    input_group = parser.add_argument_group("Input Arguments")
+    common.add_config(input_group, single_config=False)
+    input_group.add_argument("--target-info", type=utils.parse_target, default="targets,octopi,1",
+                             help="Target information, e.g., 'name' or 'name,user_id,session_id'")
+    input_group.add_argument("--tomo-alg", default='wbp',
+                             help="Tomogram algorithm used for training")
+    input_group.add_argument("--mlflow-experiment-name", type=str, default="model-search", required=False,
+                             help="Name of the MLflow experiment (default: 'model-search').")
+    input_group.add_argument("--trainRunIDs", type=utils.parse_list, default=None, required=False,
+                             help="List of training run IDs, e.g., run1,run2 or [run1,run2].")
+    input_group.add_argument("--validateRunIDs", type=utils.parse_list, default=None, required=False,
+                             help="List of validation run IDs, e.g., run3,run4 or [run3,run4].")
+    input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
+                             "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")
+
+    model_group = parser.add_argument_group("Model Arguments")
+    model_group.add_argument("--model-type", type=str, default='Unet', required=False,
+                             choices=['Unet', 'AttentionUnet'],
+                             help="Model type to use for training")
+    model_group.add_argument("--Nclass", type=int, default=3, required=False, help="Number of prediction classes in the model")
+
+    train_group = parser.add_argument_group("Training Arguments")
+    common.add_train_parameters(train_group, octopi=True)
+    train_group.add_argument("--random-seed", type=int, default=42, required=False,
+                             help="Random seed for reproducibility (default: 42).")
+
+    if add_slurm:
+        slurm_group = parser.add_argument_group("SLURM Arguments")
+        common.add_slurm_parameters(slurm_group, 'optuna')
+
+    args = parser.parse_args()
+    return args
+
+# Entry point with argparse
+def cli():
+    """
+    CLI entry point for running Optuna model architecture search.
+    """
+
+    description = "Perform model architecture search with Optuna and MLflow integration."
+    args = optuna_parser(description)
+
+    # Parse the CoPick configuration paths
+    if len(args.config) > 1: copick_configs = utils.parse_copick_configs(args.config)
+    else: copick_configs = args.config[0]
+
+    # Create the model exploration directory
+    os.makedirs(f'explore_results_{args.model_type}', exist_ok=True)
+
+    # Save YAML with Parameters
+    save_parameters(args, f'explore_results_{args.model_type}/octopi.yaml')
+
+    # Call the function with parsed arguments
+    search = ModelSearchSubmit(
+        copick_config=copick_configs,
+        target_name=args.target_info[0],
+        target_user_id=args.target_info[1],
+        target_session_id=args.target_info[2],
+        tomo_algorithm=args.tomo_alg,
+        voxel_size=args.voxel_size,
+        Nclass=args.Nclass,
+        model_type=args.model_type,
+        mlflow_experiment_name=args.mlflow_experiment_name,
+        random_seed=args.random_seed,
+        num_epochs=args.num_epochs,
+        num_trials=args.num_trials,
+        trainRunIDs=args.trainRunIDs,
+        validateRunIDs=args.validateRunIDs,
+        tomo_batch_size=args.tomo_batch_size,
+        best_metric=args.best_metric,
+        val_interval=args.val_interval,
+        data_split=args.data_split
+    )
+
+    # Run the model search
+    search.run_model_search()
+
+def save_parameters(args: argparse.Namespace,
+                    output_path: str):
+    """
+    Save the Optuna search parameters to a YAML file.
+    Args:
+        args: Parsed arguments from argparse.
+        output_path: Path to save the YAML file.
+    """
+    # Organize parameters into categories
+    params = {
+        "input": {
+            "copick_config": args.config,
+            "target_info": args.target_info,
+            "tomo_algorithm": args.tomo_alg,
+            "voxel_size": args.voxel_size,
+            "Nclass": args.Nclass,
+        },
+        "optimization": {
+            "model_type": args.model_type,
+            "mlflow_experiment_name": args.mlflow_experiment_name,
+            "random_seed": args.random_seed,
+            "num_trials": args.num_trials,
+            "best_metric": args.best_metric
+        },
+        "training": {
+            "num_epochs": args.num_epochs,
+            "tomo_batch_size": args.tomo_batch_size,
+            "trainRunIDs": args.trainRunIDs,
+            "validateRunIDs": args.validateRunIDs,
+            "data_split": args.data_split
+        }
+    }
+
+    # Print the parameters
+    print("\nParameters for Model Architecture Search:")
+    pprint.pprint(params); print()
+
+    # Save to YAML file
+    utils.save_parameters_yaml(params, output_path)
+
+if __name__ == "__main__":
+    cli()
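Likewise, a minimal sketch of launching the search without argparse; the paths are hypothetical, the keyword names mirror the ModelSearchSubmit call in cli() above, and values marked "assumed" correspond to options added by common.add_train_parameters whose defaults are not shown in this diff:

from octopi.pytorch.model_search_submitter import ModelSearchSubmit

search = ModelSearchSubmit(
    copick_config="config.json",        # hypothetical CoPick project file
    target_name="targets",              # from the default "targets,octopi,1"
    target_user_id="octopi",
    target_session_id="1",
    tomo_algorithm="wbp",               # parser default for --tomo-alg
    voxel_size=10,                      # assumed
    Nclass=3,
    model_type="Unet",                  # or "AttentionUnet"
    mlflow_experiment_name="model-search",
    random_seed=42,
    num_epochs=100,                     # assumed
    num_trials=25,                      # assumed
    trainRunIDs=None,
    validateRunIDs=None,
    tomo_batch_size=4,                  # assumed
    best_metric="avg_f1",               # assumed
    val_interval=5,                     # assumed
    data_split="0.8",                   # 80/20/0 train/val/test
)
search.run_model_search()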
--- /dev/null
+++ octopi/entry_points/run_segment_predict.py
@@ -0,0 +1,166 @@
+from octopi.pytorch import segmentation
+from octopi.entry_points import common
+import torch, argparse, json, pprint, yaml, os
+from octopi import utils
+from typing import List, Tuple
+
+def inference(
+    copick_config_path: str,
+    model_weights: str,
+    model_config: str,
+    seg_info: Tuple[str, str, str],
+    voxel_size: float,
+    tomo_algorithm: str,
+    tomo_batch_size: int,
+    run_ids: List[str],
+    ):
+    """
+    Perform segmentation inference using a model on provided tomograms.
+
+    Args:
+        copick_config_path (str): Path to CoPick configuration file.
+        run_ids (List[str]): List of tomogram run IDs for inference.
+        model_weights (str): Path to the trained model weights file.
+        model_config (str): Path to the model configuration file.
+        seg_info (Tuple[str, str, str]): Output segmentation query.
+        tomo_batch_size (int): Number of tomograms to predict per batch.
+        voxel_size (float): Voxel size for tomogram reconstruction.
+        tomo_algorithm (str): Tomogram reconstruction algorithm to use.
+        segmentation_name (str): Name for the segmentation output (seg_info[0]).
+        segmentation_user_id (str): User ID associated with the segmentation (seg_info[1]).
+        segmentation_session_id (str): Session ID for this segmentation run (seg_info[2]).
+    """
+
+    gpu_count = torch.cuda.device_count()
+    print(f"Number of GPUs available: {gpu_count}")
+
+    if gpu_count > 1:
+        print("Using Multi-GPU Predictor.")
+        predict = segmentation.MultiGPUPredictor(
+            copick_config_path,
+            model_config,
+            model_weights
+        )
+
+        # Run Multi-GPU inference
+        predict.multi_gpu_inference(
+            runIDs=run_ids,
+            tomo_algorithm=tomo_algorithm,
+            voxel_spacing=voxel_size,
+            segmentation_name=seg_info[0],
+            segmentation_user_id=seg_info[1],
+            segmentation_session_id=seg_info[2],
+            save=True
+        )
+
+    else:
+        print("Using Single-GPU Predictor.")
+        predict = segmentation.Predictor(
+            copick_config_path,
+            model_config,
+            model_weights,
+        )
+
+        # Run batch prediction
+        predict.batch_predict(
+            runIDs=run_ids,
+            num_tomos_per_batch=tomo_batch_size,
+            tomo_algorithm=tomo_algorithm,
+            voxel_spacing=voxel_size,
+            segmentation_name=seg_info[0],
+            segmentation_user_id=seg_info[1],
+            segmentation_session_id=seg_info[2]
+        )
+
+    print("Inference completed successfully.")
+
+def inference_parser(parser_description, add_slurm: bool = False):
+    """
+    Parse the arguments for inference.
+    """
+    parser = argparse.ArgumentParser(
+        description=parser_description,
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    input_group = parser.add_argument_group("Input Arguments")
+    common.add_config(input_group, single_config=True)
+
+    model_group = parser.add_argument_group("Model Arguments")
+    common.inference_model_parameters(model_group)
+
+    inference_group = parser.add_argument_group("Inference Arguments")
+    common.add_inference_parameters(inference_group)
+
+    if add_slurm:
+        slurm_group = parser.add_argument_group("SLURM Arguments")
+        common.add_slurm_parameters(slurm_group, 'segment_predict', gpus=2)
+
+    args = parser.parse_args()
+    return args
+
+# Entry point with argparse
+def cli():
+    """
+    CLI entry point for running inference.
+    """
+
+    # Parse the arguments
+    parser_description = "Run segmentation predictions with a specified model and configuration on CryoET Tomograms."
+    args = inference_parser(parser_description)
+
+    # Set default values if not provided
+    args.seg_info = list(args.seg_info)  # Convert tuple to list
+    if args.seg_info[1] is None:
+        args.seg_info[1] = "octopi"
+
+    if args.seg_info[2] is None:
+        args.seg_info[2] = "1"
+
+    # Save YAML with Parameters
+    output_yaml = f'segment-predict_{args.seg_info[1]}_{args.seg_info[2]}_{args.seg_info[0]}.yaml'
+    save_parameters(args, output_yaml)
+
+    # Call the inference function with parsed arguments
+    inference(
+        copick_config_path=args.config,
+        model_weights=args.model_weights,
+        model_config=args.model_config,
+        seg_info=args.seg_info,
+        voxel_size=args.voxel_size,
+        tomo_algorithm=args.tomo_alg,
+        tomo_batch_size=args.tomo_batch_size,
+        run_ids=args.run_ids,
+    )
+
+def save_parameters(args: argparse.Namespace,
+                    output_path: str):
+
+    # Load the model config
+    model_config = utils.load_yaml(args.model_config)
+
+    # Create parameters dictionary
+    params = {
+        "inputs": {
+            "config": args.config,
+            "model_config": args.model_config,
+            "model_weights": args.model_weights,
+            "tomo_algorithm": args.tomo_alg,
+            "voxel_size": args.voxel_size
+        },
+        "outputs": {
+            "segmentation_name": args.seg_info[0],
+            "segmentation_user_id": args.seg_info[1],
+            "segmentation_session_id": args.seg_info[2]
+        },
+        'model': model_config['model']
+    }
+
+    # Print the parameters
+    print("\nParameters for Inference (Segment Prediction):")
+    pprint.pprint(params); print()
+
+    # Save to YAML file
+    utils.save_parameters_yaml(params, output_path)
+
+if __name__ == "__main__":
+    cli()
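And a matching sketch for this prediction entry point, calling inference() directly; the weights path, model YAML, and output IDs are hypothetical placeholders, with keywords taken from the inference() signature above:

from octopi.entry_points.run_segment_predict import inference

inference(
    copick_config_path="config.json",     # hypothetical CoPick project file
    model_weights="best_model.pth",       # hypothetical checkpoint path
    model_config="model_config.yaml",     # hypothetical model YAML
    seg_info=("predict", "octopi", "1"),  # (name, user_id, session_id) for the output
    voxel_size=10,
    tomo_algorithm="wbp",
    tomo_batch_size=15,                   # only used on the single-GPU path; value assumed
    run_ids=None,                         # None = all runs in the project
)

Whether the multi- or single-GPU predictor runs is decided inside inference() from torch.cuda.device_count().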