octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/datasets/mixup.py
ADDED
@@ -0,0 +1,49 @@
+from monai.transforms import Transform
+from torch.distributions import Beta
+from torch import nn
+import numpy as np
+import torch
+
+class MixupTransformd(Transform):
+    """
+    A dictionary-based wrapper for Mixup augmentation that applies to batches.
+    This needs to be applied after batching, typically directly in the training loop.
+    """
+    def __init__(self, keys=("image", "label"), mix_beta=0.2, mixadd=False, prob=0.5):
+        self.keys = keys
+        self.mixup = Mixup(mix_beta=mix_beta, mixadd=mixadd)
+        self.prob = prob
+
+    def __call__(self, data):
+        d = dict(data)
+        if np.random.random() < self.prob:  # Apply with probability
+            d[self.keys[0]], d[self.keys[1]] = self.mixup(d[self.keys[0]], d[self.keys[1]])
+        return d
+
+class Mixup(nn.Module):
+    def __init__(self, mix_beta, mixadd=False):
+
+        super(Mixup, self).__init__()
+        self.beta_distribution = Beta(mix_beta, mix_beta)
+        self.mixadd = mixadd
+
+    def forward(self, X, Y, Z=None):
+
+        bs = X.shape[0]
+        n_dims = len(X.shape)
+        perm = torch.randperm(bs)
+        coeffs = self.beta_distribution.rsample(torch.Size((bs,))).to(X.device)
+        X_coeffs = coeffs.view((-1,) + (1,)*(X.ndim-1))
+        Y_coeffs = coeffs.view((-1,) + (1,)*(Y.ndim-1))
+
+        X = X_coeffs * X + (1-X_coeffs) * X[perm]
+
+        if self.mixadd:
+            Y = (Y + Y[perm]).clip(0, 1)
+        else:
+            Y = Y_coeffs * Y + (1 - Y_coeffs) * Y[perm]
+
+        if Z:
+            return X, Y, Z
+
+        return X, Y
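The docstring above notes that MixupTransformd is meant to be applied after batching, directly in the training loop, rather than inside the per-sample transform pipeline. Below is a minimal usage sketch under assumed conditions: the batch layout (a dict of "image"/"label" tensors shaped (B, C, D, H, W)) and the one-hot style labels are illustrative assumptions, not values taken from the package.

```python
# Hedged sketch: apply MixupTransformd to an already-batched dict, as the
# class docstring suggests. Shapes and keys below are assumptions.
import torch
from octopi.datasets.mixup import MixupTransformd

mixup = MixupTransformd(keys=("image", "label"), mix_beta=0.2, mixadd=False, prob=0.5)

# Dummy batch standing in for the output of a MONAI DataLoader.
batch = {
    "image": torch.randn(2, 1, 64, 64, 64),                    # (B, C, D, H, W)
    "label": torch.randint(0, 2, (2, 3, 64, 64, 64)).float(),  # one-hot style labels
}

# Inside the training loop, after batching and before the forward pass:
batch = mixup(batch)
print(batch["image"].shape, batch["label"].shape)
```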
octopi/datasets/multi_config_generator.py
ADDED
@@ -0,0 +1,253 @@
+from octopi.datasets import dataset, augment, cached_datset
+from octopi.datasets.generators import TrainLoaderManager
+from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
+from octopi import io
+import multiprocess as mp
+from typing import List
+from tqdm import tqdm
+import torch, gc
+
+class MultiConfigTrainLoaderManager(TrainLoaderManager):
+
+    def __init__(self,
+                 configs: dict,  # Dictionary of session names and config paths
+                 target_name: str,
+                 target_session_id: str = None,
+                 target_user_id: str = None,
+                 voxel_size: float = 10,
+                 tomo_algorithm: List[str] = ['wbp'],
+                 tomo_batch_size: int = 15,
+                 Nclasses: int = 3):
+        """
+        Initialize MultiConfigTrainLoaderManager with multiple configs.
+
+        Args:
+            configs (list): List of config file paths.
+            Other arguments are inherited from TrainLoaderManager.
+        """
+
+        # Initialize shared attributes manually (skip super().__init__ to avoid invalid config handling)
+        self.config = configs
+        self.roots = {name: io.load_copick_config(path) for name, path in configs.items()}
+
+        # Target and algorithm parameters
+        self.target_name = target_name
+        self.target_session_id = target_session_id
+        self.target_user_id = target_user_id
+        self.voxel_size = voxel_size
+        self.tomo_algorithm = tomo_algorithm
+
+        # Data management parameters
+        self.Nclasses = Nclasses
+        self.tomo_batch_size = tomo_batch_size
+        self.reload_training_dataset = True
+        self.reload_validation_dataset = True
+        self.val_loader = None
+        self.train_loader = None
+
+        # Initialize Run IDs placeholder
+        self.myRunIDs = {}
+
+        # Initialize the input dimensions
+        self.nx = None
+        self.ny = None
+        self.nz = None
+
+    def get_available_runIDs(self):
+        """
+        Identify and return a combined list of run IDs with available segmentations
+        across all configured CoPick projects.
+
+        Returns:
+            List of tuples: [(session_name, run_name), ...]
+        """
+        available_runIDs = []
+        for name, root in self.roots.items():
+            runIDs = [run.name for run in root.runs]
+            for run in runIDs:
+                run = root.get_run(run)
+                seg = run.get_segmentations(
+                    name=self.target_name,
+                    session_id=self.target_session_id,
+                    user_id=self.target_user_id,
+                    voxel_size=float(self.voxel_size)
+                )
+                if len(seg) > 0:
+                    available_runIDs.append((name, run.name))  # Include session name for disambiguation
+
+        # If No Segmentations are Found, Inform the User
+        if len(available_runIDs) == 0:
+            print(
+                f"[Error] No segmentations found for the target query:\n"
+                f"TargetName: {self.target_name}, UserID: {self.target_user_id}, "
+                f"SessionID: {self.target_session_id}\n"
+                f"Please check the target name, user ID, and session ID.\n"
+            )
+            exit()
+
+        return available_runIDs
+
+    def get_data_splits(self,
+                        trainRunIDs: str = None,
+                        validateRunIDs: str = None,
+                        train_ratio: float = 0.8,
+                        val_ratio: float = 0.1,
+                        test_ratio: float = 0.1,
+                        create_test_dataset: bool = True):
+        """
+        Override to handle run IDs as (session_name, run_name) tuples.
+        """
+        # Use the get_available_runIDs method to handle multiple projects
+        runIDs = self.get_available_runIDs()
+        return super().get_data_splits(trainRunIDs = runIDs,
+                                       train_ratio = train_ratio,
+                                       val_ratio = val_ratio,
+                                       test_ratio = test_ratio,
+                                       create_test_dataset = create_test_dataset)
+
+    def _initialize_train_iterators(self):
+        """
+        Initialize the training data iterators with multi-config support.
+        """
+        self.padded_train_list = self._get_padded_list(self.myRunIDs['train'], self.train_batch_size)
+        self.train_data_iter = iter(self._get_data_batches(self.padded_train_list, self.train_batch_size))
+
+    def _initialize_val_iterators(self):
+        """
+        Initialize the validation data iterators with multi-config support.
+        """
+        self.padded_val_list = self._get_padded_list(self.myRunIDs['validate'], self.val_batch_size)
+        self.val_data_iter = iter(self._get_data_batches(self.padded_val_list, self.val_batch_size))
+
+    def _load_data(self, runIDs):
+        """
+        Load data from multiple CoPick projects for given run IDs.
+
+        Args:
+            runIDs (list): List of (session_name, run_name) tuples.
+
+        Returns:
+            List: Combined dataset for the specified run IDs.
+        """
+
+        data = []
+        for session_name, run_name in tqdm(runIDs):
+            root = self.roots[session_name]
+            data.extend(io.load_training_data(
+                root, [run_name], self.voxel_size, self.tomo_algorithm,
+                self.target_name, self.target_session_id, self.target_user_id,
+                progress_update=False ))
+        self._check_max_label_value(data)
+        return data
+
+    def create_train_dataloaders(self, *args, **kwargs):
+        """
+        Override data loading to fetch from multiple projects.
+        """
+        my_crop_size = kwargs.get("crop_size", 96)
+        my_num_samples = kwargs.get("num_samples", 128)
+
+        # If reloads are disabled and loaders already exist, reuse them
+        if self.reload_frequency < 0 and (self.train_loader is not None) and (self.val_loader is not None):
+            return self.train_loader, self.val_loader
+
+        # Estimate Max Number of Threads with mp.cpu_count
+        n_procs = min(mp.cpu_count(), 4)
+
+        if self.train_loader is None:
+            # Fetch the next batch of run IDs
+            trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+            train_files = self._load_data(trainRunIDs)
+
+            # # Create the cached dataset with non-random transforms
+            train_ds = SmartCacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=0.5)
+
+            # # Delete the training files to free memory
+            train_files = None
+            gc.collect()
+
+            # Create the cached dataset with non-random transforms
+            # train_ds = cached_datset.MultiConfigCacheDataset(
+            #     self, trainRunIDs, transform=augment.get_transforms(), cache_rate=1.0
+            # )
+
+            # I need to read (nx,ny,nz) and scale the crop size to make sure it isnt larger than nx.
+            if self.nx is None: (self.nx,self.ny,self.nz) = train_ds[0]['image'].shape[1:]
+            self.input_dim = io.get_input_dimensions(train_ds, my_crop_size)
+
+            # Wrap the cached dataset to apply random transforms during iteration
+            self.dynamic_train_dataset = dataset.DynamicDataset(
+                data=train_ds,
+                transform=augment.get_random_transforms(self.input_dim, my_num_samples, self.Nclasses)
+            )
+
+            self.train_loader = DataLoader(
+                self.dynamic_train_dataset,
+                batch_size=1,
+                shuffle=True,
+                num_workers=n_procs,
+                pin_memory=torch.cuda.is_available(),
+            )
+
+        else:
+            # Fetch the next batch of run IDs
+            trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+            train_files = self._load_data(trainRunIDs)
+            train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
+            self.dynamic_train_dataset.update_data(train_ds)
+
+        # We Only Need to Reload the Validation Dataset if the Total Number of Runs is larger than
+        # the tomo batch size
+        if self.val_loader is None:
+
+            validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+            val_files = self._load_data(validateRunIDs)
+
+            # # Create validation dataset
+            val_ds = SmartCacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+
+            # # Delete the validation files to free memory
+            val_files = None
+            gc.collect()
+
+            # Create the cached dataset with non-random transforms
+            # val_ds = cached_datset.MultiConfigCacheDataset(
+            #     self, validateRunIDs, transform=augment.get_transforms(), cache_rate=1.0
+            # )
+
+            # # I need to read (nx,ny,nz) and scale the crop size to make sure it isnt larger than nx.
+            # if self.nx is None:
+            #     (self.nx,self.ny,self.nz) = val_ds[0]['image'].shape[1:]
+
+            # if crop_size > self.nx: self.input_dim = (self.nx, crop_size, crop_size)
+            # else: self.input_dim = (crop_size, crop_size, crop_size)
+
+            # Wrap the cached dataset to apply random transforms during iteration
+            self.dynamic_validation_dataset = dataset.DynamicDataset( data=val_ds )
+
+            # Create validation DataLoader
+            self.val_loader = DataLoader(
+                self.dynamic_validation_dataset,
+                batch_size=1,
+                num_workers=n_procs,
+                pin_memory=torch.cuda.is_available(),
+                shuffle=False,  # Ensure the data order remains consistent
+            )
+        else:
+            validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+            val_files = self._load_data(validateRunIDs)
+
+            val_ds = CacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+            self.dynamic_validation_dataset.update_data(val_ds)
+
+        return self.train_loader, self.val_loader
+
+
+    def tmp_return_datasets(self):
+        trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+        train_files = self._load_data(trainRunIDs)
+
+        validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+        val_files = self._load_data(validateRunIDs)
+
+        return train_files, val_files
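As a reading aid, here is a hedged sketch of how MultiConfigTrainLoaderManager might be driven end to end. The config paths, the target query values, and the call order are assumptions; attributes consulted by create_train_dataloaders, such as reload_frequency and the train/validation batch sizes, are expected to come from the inherited TrainLoaderManager, which is not shown in this diff.

```python
# Hedged usage sketch; paths and target values below are placeholders.
from octopi.datasets.multi_config_generator import MultiConfigTrainLoaderManager

configs = {
    "session_A": "/path/to/copick_config_A.json",
    "session_B": "/path/to/copick_config_B.json",
}

manager = MultiConfigTrainLoaderManager(
    configs,
    target_name="targets",        # assumed segmentation target name
    target_user_id="octopi",
    target_session_id="1",
    voxel_size=10,
    tomo_algorithm=["wbp"],
    tomo_batch_size=15,
    Nclasses=3,
)

# Collect (session_name, run_name) pairs that have the target segmentation
# and split them into train/val/test via the parent class.
manager.get_data_splits(train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)

# Build the MONAI DataLoaders; crop_size and num_samples are read from kwargs
# inside create_train_dataloaders (defaults 96 and 128).
train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=128)
```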
octopi/entry_points/__init__.py
File without changes
octopi/entry_points/common.py
ADDED
@@ -0,0 +1,80 @@
+from octopi import utils
+import argparse
+
+def add_model_parameters(parser, octopi = False):
+    """
+    Add common model parameters to the parser.
+    """
+
+    # Add U-Net model parameters
+    parser.add_argument("--Nclass", type=int, required=False, default=3, help="Number of prediction classes in the model")
+    parser.add_argument("--channels", type=utils.parse_int_list, required=False, default='32,64,128,128', help="List of channel sizes")
+    parser.add_argument("--strides", type=utils.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
+    parser.add_argument("--res-units", type=int, required=False, default=2, help="Number of residual units in the UNet")
+    parser.add_argument("--dim-in", type=int, required=False, default=96, help="Input dimension for the UNet model")
+
+def inference_model_parameters(parser):
+    """
+    Add model parameters for inference.
+    """
+    parser.add_argument("--model-config", type=str, required=True, help="Path to the model configuration file")
+    parser.add_argument("--model-weights", type=str, required=True, help="Path to the model weights file")
+
+def add_train_parameters(parser, octopi = False):
+    """
+    Add training parameters to the parser.
+    """
+    parser.add_argument("--num-epochs", type=int, required=False, default=100, help="Number of training epochs")
+    parser.add_argument("--val-interval", type=int, required=False, default=10, help="Interval for validation metric calculations")
+    parser.add_argument("--tomo-batch-size", type=int, required=False, default=15, help="Number of tomograms to load per epoch for training")
+    parser.add_argument("--best-metric", type=str, default='avg_f1', required=False, help="Metric to Monitor for Determining Best Model. To track fBetaN, use fBetaN with N as the beta-value.")
+
+    if not octopi:
+        parser.add_argument("--num-tomo-crops", type=int, required=False, default=16, help="Number of tomogram crops to use per patch")
+        parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate for the optimizer")
+        parser.add_argument("--tversky-alpha", type=float, required=False, default=0.5, help="Alpha parameter for the Tversky loss")
+        parser.add_argument("--model-save-path", required=True, help="Path to model save directory")
+    else:
+        parser.add_argument("--num-trials", type=int, default=10, required=False, help="Number of trials for architecture search (default: 10).")
+
+
+def add_config(parser, single_config):
+    if single_config:
+        parser.add_argument("--config", type=str, required=True, help="Path to the configuration file.")
+    else:
+        parser.add_argument("--config", type=str, required=True, action='append',
+                            help="Specify a single configuration path (/path/to/config.json) "
+                                 "or multiple entries in the format session_name,/path/to/config.json. "
+                                 "Use multiple --config entries for multiple sessions.")
+    parser.add_argument("--voxel-size", type=float, required=False, default=10, help="Voxel size of tomograms used")
+
+def add_inference_parameters(parser):
+
+    parser.add_argument("--tomo-alg", required=False, default = 'wbp',
+                        help="Tomogram algorithm used for produces segmentation prediction masks.")
+    parser.add_argument("--seg-info", type=utils.parse_target, required=False,
+                        default='predict,octopi,1', help='Information Query to save Segmentation predictions under, e.g., (e.g., "name" or "name,user_id,session_id" - Default UserID is octopi and SessionID is 1')
+    parser.add_argument("--tomo-batch-size", type=int, default=25, required=False,
+                        help="Batch size for tomogram processing.")
+    parser.add_argument("--run-ids", type=utils.parse_list, default=None, required=False,
+                        help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")
+
+def add_localize_parameters(parser):
+
+    parser.add_argument("--voxel-size", type=int, required=False, default=10, help="Voxel size")
+    parser.add_argument("--method", type=str,required=False, default='watershed', help="Localization method")
+    parser.add_argument("--pick-session-id", required=False, default="1", type=str, help="Pick session ID")
+    parser.add_argument("--pick-objects", required=True, type=str, help="Pick objects")
+    parser.add_argument("--seg-info", required=True, type=str, help="Segmentation info")
+
+def add_slurm_parameters(parser, base_job_name, gpus = 1):
+    """
+    Add SLURM job parameters to the parser.
+    """
+    parser.add_argument("--conda-env", type=str, required=False, default='/hpc/projects/group.czii/conda_environments/pyUNET/', help="Path to Conda environment")
+    parser.add_argument("--job-name", type=str, required=False, default=f'{base_job_name}', help="Job name for SLURM job")
+
+    if gpus > 0:
+        parser.add_argument("--gpu-constraint", type=str.lower, choices=['a6000', 'a100', 'h100', 'h200'], required=False, default='h100', help="GPU constraint")
+    if gpus > 1:
+        parser.add_argument("--num-gpus", type=int, required=False, default=1, help="Number of GPUs to use")
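To show how these helpers compose, here is a hedged sketch of assembling a training-style parser from them. The description string and example argument values are illustrative only; the package's real entry points (e.g., run_train) presumably build their own parsers.

```python
# Hedged sketch: composing the shared argument helpers into one parser.
import argparse
from octopi.entry_points import common

parser = argparse.ArgumentParser(description="Train a 3D CNN for particle segmentation")
common.add_config(parser, single_config=False)   # --config (repeatable) and --voxel-size
common.add_model_parameters(parser)              # --Nclass, --channels, --strides, ...
common.add_train_parameters(parser)              # --num-epochs, --lr, --model-save-path, ...

# Placeholder command line; --model-save-path is required by add_train_parameters.
args = parser.parse_args([
    "--config", "session_A,/path/to/config_A.json",
    "--model-save-path", "results/model",
    "--num-epochs", "50",
])
print(args.config, args.Nclass, args.num_epochs)
```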
octopi/entry_points/create_slurm_submission.py
ADDED
@@ -0,0 +1,243 @@
+from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
+from octopi.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
+from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
+from octopi.entry_points import common
+from octopi import utils
+import argparse
+
+def create_train_script(args):
+    """
+    Create a SLURM script for training 3D CNN models
+    """
+
+    strconfigs = ""
+    for config in args.config:
+        strconfigs += f"--config {config}"
+
+    command = f"""
+    octopi train \\
+        --model-save-path {args.model_save_path} \\
+        --target-info {args.target_info} \\
+        --voxel-size {args.voxel_size} --tomo-algorithm {args.tomo_algorithm} --Nclass {args.Nclass} \\
+        --best-metric {args.best_metric} --num-epochs {args.num_epochs} --val-interval {args.val_interval} \\
+        --tomo-batch-size {args.tomo_batch_size} --num-tomo-crops {args.num_tomo_crops} \\
+        {strconfigs}
+    """
+
+    # If a model config is provided, use it to build the model
+    if args.model_config is not None:
+        command += f" --model-config {args.model_config}"
+    else:
+        command += f" --tversky-alpha {args.tversky_alpha} --channels {args.channels} --strides {args.strides} --dim-in {args.dim_in} --res-units {args.res_units}"
+
+    # If Model Weights are provided, use them to initialize the model
+    if args.model_weights is not None and args.model_config is not None:
+        command += f" --model-weights {args.model_weights}"
+
+    create_shellsubmit(
+        job_name = args.job_name,
+        output_file = 'trainer.log',
+        shell_name = 'train_octopi.sh',
+        conda_path = args.conda_env,
+        command = command,
+        num_gpus = 1,
+        gpu_constraint = args.gpu_constraint
+    )
+
+def train_model_slurm():
+    """
+    Create a SLURM script for training 3D CNN models
+    """
+    parser_description = "Create a SLURM script for training 3D CNN models"
+    args = run_train.train_model_parser(parser_description, add_slurm=True)
+    create_train_script(args)
+
+def create_model_explore_script(args):
+    """
+    Create a SLURM script for running bayesian optimization on 3D CNN models
+    """
+    strconfigs = ""
+    for config in args.config:
+        strconfigs += f"--config {config}"
+
+    command = f"""
+    octopi model-explore \\
+        --model-type {args.model_type} --model-save-path {args.model_save_path} \\
+        --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
+        --val-interval {args.val_interval} --num-epochs {args.num_epochs} --num-trials {args.num_trials} \\
+        --best-metric {args.best_metric} --mlflow-experiment-name {args.mlflow_experiment_name} \\
+        --target-name {args.target_name} --target-session-id {args.target_session_id} --target-user-id {args.target_user_id} \\
+        {strconfigs}
+    """
+
+    create_shellsubmit(
+        job_name = args.job_name,
+        output_file = 'optuna.log',
+        shell_name = 'model_explore.sh',
+        conda_path = args.conda_env,
+        command = command,
+        num_gpus = 1,
+        gpu_constraint = args.gpu_constraint
+    )
+
+def model_explore_slurm():
+    """
+    Create a SLURM script for running bayesian optimization on 3D CNN models
+    """
+    parser_description = "Create a SLURM script for running bayesian optimization on 3D CNN models"
+    args = run_optuna.optuna_parser(parser_description, add_slurm=True)
+    create_model_explore_script(args)
+
+def create_inference_script(args):
+    """
+    Create a SLURM script for running inference on 3D CNN models
+    """
+
+    if len(args.config.split(',')) > 1:
+
+        create_multiconfig_shellsubmit(
+            job_name = args.job_name,
+            output_file = 'predict.log',
+            shell_name = 'segment.sh',
+            conda_path = args.conda_env,
+            base_inputs = args.base_inputs,
+            config_inputs = args.config_inputs,
+            command = args.command,
+            num_gpus = args.num_gpus,
+            gpu_constraint = args.gpu_constraint
+        )
+    else:
+
+        command = f"""
+        octopi inference \\
+            --config {args.config} \\
+            --seg-info {",".join(args.seg_info)} \\
+            --model-weights {args.model_weights} \\
+            --dim-in {args.dim_in} --res-units {args.res_units} \\
+            --model-type {args.model_type} --channels {",".join(map(str, args.channels))} --strides {",".join(map(str, args.strides))} \\
+            --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass}
+        """
+
+        create_shellsubmit(
+            job_name = args.job_name,
+            output_file = 'predict.log',
+            shell_name = 'segment.sh',
+            conda_path = args.conda_env,
+            command = command,
+            num_gpus = 1,
+            gpu_constraint = args.gpu_constraint
+        )
+
+def inference_slurm():
+    """
+    Create a SLURM script for running segmentation predictions with a specified model and configuration on CryoET Tomograms.
+    """
+    parser_description = "Create a SLURM script for running segmentation predictions with a specified model and configuration on CryoET Tomograms."
+    args = run_segment_predict.inference_parser(parser_description, add_slurm=True)
+    create_inference_script(args)
+
+def create_localize_script(args):
+    """"
+    Create a SLURM script for running localization on predicted segmentation masks
+    """
+    if len(args.config.split(',')) > 1:
+
+        create_multiconfig_shellsubmit(
+            job_name = args.job_name,
+            output_file = args.output,
+            shell_name = args.output_script,
+            conda_path = args.conda_env,
+            base_inputs = args.base_inputs,
+            config_inputs = args.config_inputs,
+            command = args.command
+        )
+    else:
+
+        command = f"""
+        octopi localize \\
+            --config {args.config} \\
+            --voxel-size {args.voxel_size} --pick-session-id {args.pick_session_id} --pick-user-id {args.pick_user_id} \\
+            --method {args.method} --seg-info {",".join(args.seg_info)} \\
+        """
+        if args.pick_objects is not None:
+            command += f" --pick-objects {args.pick_objects}"
+
+        create_shellsubmit(
+            job_name = args.job_name,
+            output_file = 'localize.log',
+            shell_name = 'localize.sh',
+            conda_path = args.conda_env,
+            command = command,
+            num_gpus = 0
+        )
+
+def localize_slurm():
+    """
+    Create a SLURM script for running localization on predicted segmentation masks
+    """
+    parser_description = "Create a SLURM script for localization on predicted segmentation masks"
+    args = run_localize.localize_parser(parser_description, add_slurm=True)
+    create_localize_script(args)
+
+def create_extract_mb_picks_script(args):
+    pass
+
+def extract_mb_picks_slurm():
+    pass
+
+
+def create_import_mrc_script(args):
+    """
+    Create a SLURM script for importing mrc volumes and potentialy downsampling
+    """
+    command = f"""
+    octopi import-mrc-volumes \\
+        --mrcs-path {args.mrcs_path} \\
+        --config {args.config} --target-tomo-type {args.target_tomo_type} \\
+        --input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
+    """
+
+    create_shellsubmit(
+        job_name = args.job_name,
+        output_file = 'importer.log',
+        shell_name = 'mrc_importer.sh',
+        conda_path = args.conda_env,
+        command = command
+    )
+
+def import_mrc_slurm():
+    """
+    Create a SLURM script for importing mrc volumes and potentialy downsampling
+    """
+    parser_description = "Create a SLURM script for importing mrc volumes and potentialy downsampling"
+    args = cli_mrcs_parser(parser_description, add_slurm=True)
+    create_import_mrc_script(args)
+
+
+def create_download_dataportal_script(args):
+    """
+    Create a SLURM script for downloading tomograms from the Dataportal
+    """
+    command = f"""
+    octopi download-dataportal \\
+        --config {args.config} --datasetID {args.datasetID} \\
+        --overlay-path {args.overlay_path}
+        --dataportal-name {args.dataportal_name} --target-tomo-type {args.target_tomo_type} \\
+        --input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
+    """
+
+    create_shellsubmit(
+        job_name = args.job_name,
+        output_file = 'importer.log',
+        shell_name = 'dataportal_importer.sh',
+        conda_path = args.conda_env,
+        command = command
+    )
+
+def download_dataportal_slurm():
+    """
+    Create a SLURM script for downloading tomograms from the Dataportal
+    """
+    parser_description = "Create a SLURM script for downloading tomograms from the Dataportal"
+    args = cli_dataportal_parser(parser_description, add_slurm=True)
+    create_download_dataportal_script(args)