octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of octopi might be problematic.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/datasets/mixup.py
@@ -0,0 +1,49 @@
+ from monai.transforms import Transform
+ from torch.distributions import Beta
+ from torch import nn
+ import numpy as np
+ import torch
+
+ class MixupTransformd(Transform):
+     """
+     A dictionary-based wrapper for Mixup augmentation that applies to batches.
+     This needs to be applied after batching, typically directly in the training loop.
+     """
+     def __init__(self, keys=("image", "label"), mix_beta=0.2, mixadd=False, prob=0.5):
+         self.keys = keys
+         self.mixup = Mixup(mix_beta=mix_beta, mixadd=mixadd)
+         self.prob = prob
+
+     def __call__(self, data):
+         d = dict(data)
+         if np.random.random() < self.prob:  # Apply with probability
+             d[self.keys[0]], d[self.keys[1]] = self.mixup(d[self.keys[0]], d[self.keys[1]])
+         return d
+
+ class Mixup(nn.Module):
+     def __init__(self, mix_beta, mixadd=False):
+         super(Mixup, self).__init__()
+         self.beta_distribution = Beta(mix_beta, mix_beta)
+         self.mixadd = mixadd
+
+     def forward(self, X, Y, Z=None):
+         bs = X.shape[0]
+         n_dims = len(X.shape)
+         perm = torch.randperm(bs)
+         coeffs = self.beta_distribution.rsample(torch.Size((bs,))).to(X.device)
+         X_coeffs = coeffs.view((-1,) + (1,) * (X.ndim - 1))
+         Y_coeffs = coeffs.view((-1,) + (1,) * (Y.ndim - 1))
+
+         X = X_coeffs * X + (1 - X_coeffs) * X[perm]
+
+         if self.mixadd:
+             Y = (Y + Y[perm]).clip(0, 1)
+         else:
+             Y = Y_coeffs * Y + (1 - Y_coeffs) * Y[perm]
+
+         # Explicit None check: truth-testing a multi-element tensor raises a RuntimeError
+         if Z is not None:
+             return X, Y, Z
+
+         return X, Y
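The docstring above notes that MixupTransformd is meant to run on an already-batched dictionary inside the training loop. Below is a minimal usage sketch; the batch shapes and the one-hot label layout are illustrative assumptions, not taken from the package.

# Illustrative sketch: apply MixupTransformd to an already-batched dict.
import torch
from octopi.datasets.mixup import MixupTransformd

mixup = MixupTransformd(keys=("image", "label"), mix_beta=0.2, mixadd=False, prob=0.5)

batch = {
    "image": torch.randn(4, 1, 96, 96, 96),  # assumed (B, C, D, H, W) crops
    "label": torch.nn.functional.one_hot(
        torch.randint(0, 3, (4, 96, 96, 96)), 3
    ).permute(0, 4, 1, 2, 3).float(),        # assumed one-hot labels, (B, Nclass, D, H, W)
}

# With probability `prob`, images and labels are replaced by convex mixtures of
# random pairs within the batch; otherwise the batch passes through unchanged.
batch = mixup(batch)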
octopi/datasets/multi_config_generator.py
@@ -0,0 +1,253 @@
+ from octopi.datasets import dataset, augment, cached_datset
+ from octopi.datasets.generators import TrainLoaderManager
+ from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
+ from octopi import io
+ import multiprocess as mp
+ from typing import List
+ from tqdm import tqdm
+ import torch, gc
+
+ class MultiConfigTrainLoaderManager(TrainLoaderManager):
+
+     def __init__(self,
+                  configs: dict,  # Dictionary of session names and config paths
+                  target_name: str,
+                  target_session_id: str = None,
+                  target_user_id: str = None,
+                  voxel_size: float = 10,
+                  tomo_algorithm: List[str] = ['wbp'],
+                  tomo_batch_size: int = 15,
+                  Nclasses: int = 3):
+         """
+         Initialize MultiConfigTrainLoaderManager with multiple configs.
+
+         Args:
+             configs (dict): Mapping of session names to config file paths.
+             Other arguments are inherited from TrainLoaderManager.
+         """
+
+         # Initialize shared attributes manually (skip super().__init__ to avoid invalid config handling)
+         self.config = configs
+         self.roots = {name: io.load_copick_config(path) for name, path in configs.items()}
+
+         # Target and algorithm parameters
+         self.target_name = target_name
+         self.target_session_id = target_session_id
+         self.target_user_id = target_user_id
+         self.voxel_size = voxel_size
+         self.tomo_algorithm = tomo_algorithm
+
+         # Data management parameters
+         self.Nclasses = Nclasses
+         self.tomo_batch_size = tomo_batch_size
+         self.reload_training_dataset = True
+         self.reload_validation_dataset = True
+         self.val_loader = None
+         self.train_loader = None
+
+         # Initialize Run IDs placeholder
+         self.myRunIDs = {}
+
+         # Initialize the input dimensions
+         self.nx = None
+         self.ny = None
+         self.nz = None
+
+     def get_available_runIDs(self):
+         """
+         Identify and return a combined list of run IDs with available segmentations
+         across all configured CoPick projects.
+
+         Returns:
+             List of tuples: [(session_name, run_name), ...]
+         """
+         available_runIDs = []
+         for name, root in self.roots.items():
+             runIDs = [run.name for run in root.runs]
+             for run in runIDs:
+                 run = root.get_run(run)
+                 seg = run.get_segmentations(
+                     name=self.target_name,
+                     session_id=self.target_session_id,
+                     user_id=self.target_user_id,
+                     voxel_size=float(self.voxel_size)
+                 )
+                 if len(seg) > 0:
+                     available_runIDs.append((name, run.name))  # Include session name for disambiguation
+
+         # If no segmentations are found, inform the user
+         if len(available_runIDs) == 0:
+             print(
+                 f"[Error] No segmentations found for the target query:\n"
+                 f"TargetName: {self.target_name}, UserID: {self.target_user_id}, "
+                 f"SessionID: {self.target_session_id}\n"
+                 f"Please check the target name, user ID, and session ID.\n"
+             )
+             exit()
+
+         return available_runIDs
+
+     def get_data_splits(self,
+                         trainRunIDs: str = None,
+                         validateRunIDs: str = None,
+                         train_ratio: float = 0.8,
+                         val_ratio: float = 0.1,
+                         test_ratio: float = 0.1,
+                         create_test_dataset: bool = True):
+         """
+         Override to handle run IDs as (session_name, run_name) tuples.
+         """
+         # Use the get_available_runIDs method to handle multiple projects
+         runIDs = self.get_available_runIDs()
+         return super().get_data_splits(trainRunIDs=runIDs,
+                                        train_ratio=train_ratio,
+                                        val_ratio=val_ratio,
+                                        test_ratio=test_ratio,
+                                        create_test_dataset=create_test_dataset)
+
+     def _initialize_train_iterators(self):
+         """
+         Initialize the training data iterators with multi-config support.
+         """
+         self.padded_train_list = self._get_padded_list(self.myRunIDs['train'], self.train_batch_size)
+         self.train_data_iter = iter(self._get_data_batches(self.padded_train_list, self.train_batch_size))
+
+     def _initialize_val_iterators(self):
+         """
+         Initialize the validation data iterators with multi-config support.
+         """
+         self.padded_val_list = self._get_padded_list(self.myRunIDs['validate'], self.val_batch_size)
+         self.val_data_iter = iter(self._get_data_batches(self.padded_val_list, self.val_batch_size))
+
+     def _load_data(self, runIDs):
+         """
+         Load data from multiple CoPick projects for given run IDs.
+
+         Args:
+             runIDs (list): List of (session_name, run_name) tuples.
+
+         Returns:
+             List: Combined dataset for the specified run IDs.
+         """
+
+         data = []
+         for session_name, run_name in tqdm(runIDs):
+             root = self.roots[session_name]
+             data.extend(io.load_training_data(
+                 root, [run_name], self.voxel_size, self.tomo_algorithm,
+                 self.target_name, self.target_session_id, self.target_user_id,
+                 progress_update=False))
+         self._check_max_label_value(data)
+         return data
+
+     def create_train_dataloaders(self, *args, **kwargs):
+         """
+         Override data loading to fetch from multiple projects.
+         """
+         my_crop_size = kwargs.get("crop_size", 96)
+         my_num_samples = kwargs.get("num_samples", 128)
+
+         # If reloads are disabled and loaders already exist, reuse them
+         if self.reload_frequency < 0 and (self.train_loader is not None) and (self.val_loader is not None):
+             return self.train_loader, self.val_loader
+
+         # Estimate the maximum number of worker processes with mp.cpu_count
+         n_procs = min(mp.cpu_count(), 4)
+
+         if self.train_loader is None:
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = self._load_data(trainRunIDs)
+
+             # Create the cached dataset with non-random transforms
+             train_ds = SmartCacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=0.5)
+
+             # Delete the training files to free memory
+             train_files = None
+             gc.collect()
+
+             # Alternative: cached dataset backed by multiple configs
+             # train_ds = cached_datset.MultiConfigCacheDataset(
+             #     self, trainRunIDs, transform=augment.get_transforms(), cache_rate=1.0
+             # )
+
+             # Read (nx, ny, nz) and scale the crop size to make sure it isn't larger than nx
+             if self.nx is None:
+                 (self.nx, self.ny, self.nz) = train_ds[0]['image'].shape[1:]
+             self.input_dim = io.get_input_dimensions(train_ds, my_crop_size)
+
+             # Wrap the cached dataset to apply random transforms during iteration
+             self.dynamic_train_dataset = dataset.DynamicDataset(
+                 data=train_ds,
+                 transform=augment.get_random_transforms(self.input_dim, my_num_samples, self.Nclasses)
+             )
+
+             self.train_loader = DataLoader(
+                 self.dynamic_train_dataset,
+                 batch_size=1,
+                 shuffle=True,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+             )
+
+         else:
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = self._load_data(trainRunIDs)
+             train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
+             self.dynamic_train_dataset.update_data(train_ds)
+
+         # We only need to reload the validation dataset if the total number of runs is larger than
+         # the tomo batch size
+         if self.val_loader is None:
+
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = self._load_data(validateRunIDs)
+
+             # Create the validation dataset with non-random transforms
+             val_ds = SmartCacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+
+             # Delete the validation files to free memory
+             val_files = None
+             gc.collect()
+
+             # Alternative: cached dataset backed by multiple configs
+             # val_ds = cached_datset.MultiConfigCacheDataset(
+             #     self, validateRunIDs, transform=augment.get_transforms(), cache_rate=1.0
+             # )
+
+             # Read (nx, ny, nz) and scale the crop size to make sure it isn't larger than nx
+             # if self.nx is None:
+             #     (self.nx, self.ny, self.nz) = val_ds[0]['image'].shape[1:]
+             # if crop_size > self.nx: self.input_dim = (self.nx, crop_size, crop_size)
+             # else: self.input_dim = (crop_size, crop_size, crop_size)
+
+             # Wrap the cached dataset to apply random transforms during iteration
+             self.dynamic_validation_dataset = dataset.DynamicDataset(data=val_ds)
+
+             # Create validation DataLoader
+             self.val_loader = DataLoader(
+                 self.dynamic_validation_dataset,
+                 batch_size=1,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+                 shuffle=False,  # Ensure the data order remains consistent
+             )
+         else:
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = self._load_data(validateRunIDs)
+
+             val_ds = CacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+             self.dynamic_validation_dataset.update_data(val_ds)
+
+         return self.train_loader, self.val_loader
+
+
+     def tmp_return_datasets(self):
+         trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+         train_files = self._load_data(trainRunIDs)
+
+         validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+         val_files = self._load_data(validateRunIDs)
+
+         return train_files, val_files
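For orientation, here is a sketch of how this multi-config manager might be driven. The session names, config paths, and target values below are placeholders, and the split/reload state used by create_train_dataloaders is set up by the parent TrainLoaderManager, so this is not a verified recipe.

# Illustrative sketch only: wiring two CoPick projects into the multi-config loader.
from octopi.datasets.multi_config_generator import MultiConfigTrainLoaderManager

manager = MultiConfigTrainLoaderManager(
    configs={"session_A": "configA.json", "session_B": "configB.json"},  # placeholder paths
    target_name="targets",        # placeholder segmentation target
    target_user_id="octopi",
    voxel_size=10,
    tomo_algorithm=["wbp"],
    tomo_batch_size=15,
    Nclasses=3,
)

# Split the available (session_name, run_name) tuples, then build the loaders.
manager.get_data_splits(train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)
train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=128)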
octopi/entry_points/common.py
@@ -0,0 +1,80 @@
+ from octopi import utils
+ import argparse
+
+ def add_model_parameters(parser, octopi=False):
+     """
+     Add common model parameters to the parser.
+     """
+
+     # Add U-Net model parameters
+     parser.add_argument("--Nclass", type=int, required=False, default=3, help="Number of prediction classes in the model")
+     parser.add_argument("--channels", type=utils.parse_int_list, required=False, default='32,64,128,128', help="List of channel sizes")
+     parser.add_argument("--strides", type=utils.parse_int_list, required=False, default='2,2,1', help="List of stride sizes")
+     parser.add_argument("--res-units", type=int, required=False, default=2, help="Number of residual units in the UNet")
+     parser.add_argument("--dim-in", type=int, required=False, default=96, help="Input dimension for the UNet model")
+
+ def inference_model_parameters(parser):
+     """
+     Add model parameters for inference.
+     """
+     parser.add_argument("--model-config", type=str, required=True, help="Path to the model configuration file")
+     parser.add_argument("--model-weights", type=str, required=True, help="Path to the model weights file")
+
+ def add_train_parameters(parser, octopi=False):
+     """
+     Add training parameters to the parser.
+     """
+     parser.add_argument("--num-epochs", type=int, required=False, default=100, help="Number of training epochs")
+     parser.add_argument("--val-interval", type=int, required=False, default=10, help="Interval for validation metric calculations")
+     parser.add_argument("--tomo-batch-size", type=int, required=False, default=15, help="Number of tomograms to load per epoch for training")
+     parser.add_argument("--best-metric", type=str, default='avg_f1', required=False, help="Metric to monitor for determining the best model. To track fBetaN, use fBetaN with N as the beta value.")
+
+     if not octopi:
+         parser.add_argument("--num-tomo-crops", type=int, required=False, default=16, help="Number of tomogram crops to use per patch")
+         parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate for the optimizer")
+         parser.add_argument("--tversky-alpha", type=float, required=False, default=0.5, help="Alpha parameter for the Tversky loss")
+         parser.add_argument("--model-save-path", required=True, help="Path to model save directory")
+     else:
+         parser.add_argument("--num-trials", type=int, default=10, required=False, help="Number of trials for architecture search (default: 10).")
+
+
+ def add_config(parser, single_config):
+     if single_config:
+         parser.add_argument("--config", type=str, required=True, help="Path to the configuration file.")
+     else:
+         parser.add_argument("--config", type=str, required=True, action='append',
+                             help="Specify a single configuration path (/path/to/config.json) "
+                                  "or multiple entries in the format session_name,/path/to/config.json. "
+                                  "Use multiple --config entries for multiple sessions.")
+     parser.add_argument("--voxel-size", type=float, required=False, default=10, help="Voxel size of tomograms used")
+
+ def add_inference_parameters(parser):
+
+     parser.add_argument("--tomo-alg", required=False, default='wbp',
+                         help="Tomogram algorithm used to produce segmentation prediction masks.")
+     parser.add_argument("--seg-info", type=utils.parse_target, required=False,
+                         default='predict,octopi,1',
+                         help='Query to save segmentation predictions under, e.g., "name" or "name,user_id,session_id". Default user ID is octopi and session ID is 1.')
+     parser.add_argument("--tomo-batch-size", type=int, default=25, required=False,
+                         help="Batch size for tomogram processing.")
+     parser.add_argument("--run-ids", type=utils.parse_list, default=None, required=False,
+                         help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")
+
+ def add_localize_parameters(parser):
+
+     parser.add_argument("--voxel-size", type=int, required=False, default=10, help="Voxel size")
+     parser.add_argument("--method", type=str, required=False, default='watershed', help="Localization method")
+     parser.add_argument("--pick-session-id", required=False, default="1", type=str, help="Pick session ID")
+     parser.add_argument("--pick-objects", required=True, type=str, help="Pick objects")
+     parser.add_argument("--seg-info", required=True, type=str, help="Segmentation info")
+
+ def add_slurm_parameters(parser, base_job_name, gpus=1):
+     """
+     Add SLURM job parameters to the parser.
+     """
+     parser.add_argument("--conda-env", type=str, required=False, default='/hpc/projects/group.czii/conda_environments/pyUNET/', help="Path to Conda environment")
+     parser.add_argument("--job-name", type=str, required=False, default=f'{base_job_name}', help="Job name for SLURM job")
+
+     if gpus > 0:
+         parser.add_argument("--gpu-constraint", type=str.lower, choices=['a6000', 'a100', 'h100', 'h200'], required=False, default='h100', help="GPU constraint")
+     if gpus > 1:
+         parser.add_argument("--num-gpus", type=int, required=False, default=1, help="Number of GPUs to use")
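These helpers are composed into the CLI parsers used by the entry points. A small sketch of how they might be combined on a fresh ArgumentParser follows; the program description and job name are placeholders.

# Sketch: composing the shared argument groups into one parser.
import argparse
from octopi.entry_points import common

parser = argparse.ArgumentParser(description="Train a 3D CNN for particle segmentation")
common.add_config(parser, single_config=False)   # repeatable --config entries plus --voxel-size
common.add_model_parameters(parser)              # --Nclass, --channels, --strides, --res-units, --dim-in
common.add_train_parameters(parser)              # --num-epochs, --lr, --tversky-alpha, --model-save-path, ...
common.add_slurm_parameters(parser, base_job_name="train-octopi", gpus=1)
args = parser.parse_args()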
octopi/entry_points/create_slurm_submission.py
@@ -0,0 +1,243 @@
+ from octopi.entry_points import run_train, run_segment_predict, run_localize, run_optuna
+ from octopi.submit_slurm import create_shellsubmit, create_multiconfig_shellsubmit
+ from octopi.processing.importers import cli_mrcs_parser, cli_dataportal_parser
+ from octopi.entry_points import common
+ from octopi import utils
+ import argparse
+
+ def create_train_script(args):
+     """
+     Create a SLURM script for training 3D CNN models
+     """
+
+     strconfigs = ""
+     for config in args.config:
+         strconfigs += f"--config {config} "  # trailing space keeps repeated --config entries separated
+
+     command = f"""
+     octopi train \\
+     --model-save-path {args.model_save_path} \\
+     --target-info {args.target_info} \\
+     --voxel-size {args.voxel_size} --tomo-algorithm {args.tomo_algorithm} --Nclass {args.Nclass} \\
+     --best-metric {args.best_metric} --num-epochs {args.num_epochs} --val-interval {args.val_interval} \\
+     --tomo-batch-size {args.tomo_batch_size} --num-tomo-crops {args.num_tomo_crops} \\
+     {strconfigs}
+     """
+
+     # If a model config is provided, use it to build the model
+     # (strip trailing whitespace so the extra flags stay on the same shell line)
+     if args.model_config is not None:
+         command = command.rstrip() + f" --model-config {args.model_config}"
+     else:
+         command = command.rstrip() + f" --tversky-alpha {args.tversky_alpha} --channels {args.channels} --strides {args.strides} --dim-in {args.dim_in} --res-units {args.res_units}"
+
+     # If model weights are provided, use them to initialize the model
+     if args.model_weights is not None and args.model_config is not None:
+         command += f" --model-weights {args.model_weights}"
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'trainer.log',
+         shell_name = 'train_octopi.sh',
+         conda_path = args.conda_env,
+         command = command,
+         num_gpus = 1,
+         gpu_constraint = args.gpu_constraint
+     )
+
+ def train_model_slurm():
+     """
+     Create a SLURM script for training 3D CNN models
+     """
+     parser_description = "Create a SLURM script for training 3D CNN models"
+     args = run_train.train_model_parser(parser_description, add_slurm=True)
+     create_train_script(args)
+
+ def create_model_explore_script(args):
+     """
+     Create a SLURM script for running Bayesian optimization on 3D CNN models
+     """
+     strconfigs = ""
+     for config in args.config:
+         strconfigs += f"--config {config} "
+
+     command = f"""
+     octopi model-explore \\
+     --model-type {args.model_type} --model-save-path {args.model_save_path} \\
+     --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass} \\
+     --val-interval {args.val_interval} --num-epochs {args.num_epochs} --num-trials {args.num_trials} \\
+     --best-metric {args.best_metric} --mlflow-experiment-name {args.mlflow_experiment_name} \\
+     --target-name {args.target_name} --target-session-id {args.target_session_id} --target-user-id {args.target_user_id} \\
+     {strconfigs}
+     """
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'optuna.log',
+         shell_name = 'model_explore.sh',
+         conda_path = args.conda_env,
+         command = command,
+         num_gpus = 1,
+         gpu_constraint = args.gpu_constraint
+     )
+
+ def model_explore_slurm():
+     """
+     Create a SLURM script for running Bayesian optimization on 3D CNN models
+     """
+     parser_description = "Create a SLURM script for running Bayesian optimization on 3D CNN models"
+     args = run_optuna.optuna_parser(parser_description, add_slurm=True)
+     create_model_explore_script(args)
+
+ def create_inference_script(args):
+     """
+     Create a SLURM script for running inference on 3D CNN models
+     """
+
+     if len(args.config.split(',')) > 1:
+
+         create_multiconfig_shellsubmit(
+             job_name = args.job_name,
+             output_file = 'predict.log',
+             shell_name = 'segment.sh',
+             conda_path = args.conda_env,
+             base_inputs = args.base_inputs,
+             config_inputs = args.config_inputs,
+             command = args.command,
+             num_gpus = args.num_gpus,
+             gpu_constraint = args.gpu_constraint
+         )
+     else:
+
+         command = f"""
+         octopi inference \\
+         --config {args.config} \\
+         --seg-info {",".join(args.seg_info)} \\
+         --model-weights {args.model_weights} \\
+         --dim-in {args.dim_in} --res-units {args.res_units} \\
+         --model-type {args.model_type} --channels {",".join(map(str, args.channels))} --strides {",".join(map(str, args.strides))} \\
+         --voxel-size {args.voxel_size} --tomo-alg {args.tomo_alg} --Nclass {args.Nclass}
+         """
+
+         create_shellsubmit(
+             job_name = args.job_name,
+             output_file = 'predict.log',
+             shell_name = 'segment.sh',
+             conda_path = args.conda_env,
+             command = command,
+             num_gpus = 1,
+             gpu_constraint = args.gpu_constraint
+         )
+
+ def inference_slurm():
+     """
+     Create a SLURM script for running segmentation predictions with a specified model and configuration on CryoET tomograms.
+     """
+     parser_description = "Create a SLURM script for running segmentation predictions with a specified model and configuration on CryoET tomograms."
+     args = run_segment_predict.inference_parser(parser_description, add_slurm=True)
+     create_inference_script(args)
+
+ def create_localize_script(args):
+     """
+     Create a SLURM script for running localization on predicted segmentation masks
+     """
+     if len(args.config.split(',')) > 1:
+
+         create_multiconfig_shellsubmit(
+             job_name = args.job_name,
+             output_file = args.output,
+             shell_name = args.output_script,
+             conda_path = args.conda_env,
+             base_inputs = args.base_inputs,
+             config_inputs = args.config_inputs,
+             command = args.command
+         )
+     else:
+
+         command = f"""
+         octopi localize \\
+         --config {args.config} \\
+         --voxel-size {args.voxel_size} --pick-session-id {args.pick_session_id} --pick-user-id {args.pick_user_id} \\
+         --method {args.method} --seg-info {",".join(args.seg_info)} \\
+         """
+         if args.pick_objects is not None:
+             command += f" --pick-objects {args.pick_objects}"
+
+         create_shellsubmit(
+             job_name = args.job_name,
+             output_file = 'localize.log',
+             shell_name = 'localize.sh',
+             conda_path = args.conda_env,
+             command = command,
+             num_gpus = 0
+         )
+
+ def localize_slurm():
+     """
+     Create a SLURM script for running localization on predicted segmentation masks
+     """
+     parser_description = "Create a SLURM script for localization on predicted segmentation masks"
+     args = run_localize.localize_parser(parser_description, add_slurm=True)
+     create_localize_script(args)
+
+ def create_extract_mb_picks_script(args):
+     pass
+
+ def extract_mb_picks_slurm():
+     pass
+
+
+ def create_import_mrc_script(args):
+     """
+     Create a SLURM script for importing MRC volumes and potentially downsampling them
+     """
+     command = f"""
+     octopi import-mrc-volumes \\
+     --mrcs-path {args.mrcs_path} \\
+     --config {args.config} --target-tomo-type {args.target_tomo_type} \\
+     --input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
+     """
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'importer.log',
+         shell_name = 'mrc_importer.sh',
+         conda_path = args.conda_env,
+         command = command
+     )
+
+ def import_mrc_slurm():
+     """
+     Create a SLURM script for importing MRC volumes and potentially downsampling them
+     """
+     parser_description = "Create a SLURM script for importing MRC volumes and potentially downsampling them"
+     args = cli_mrcs_parser(parser_description, add_slurm=True)
+     create_import_mrc_script(args)
+
+
+ def create_download_dataportal_script(args):
+     """
+     Create a SLURM script for downloading tomograms from the Dataportal
+     """
+     command = f"""
+     octopi download-dataportal \\
+     --config {args.config} --datasetID {args.datasetID} \\
+     --overlay-path {args.overlay_path} \\
+     --dataportal-name {args.dataportal_name} --target-tomo-type {args.target_tomo_type} \\
+     --input-voxel-size {args.input_voxel_size} --output-voxel-size {args.output_voxel_size}
+     """
+
+     create_shellsubmit(
+         job_name = args.job_name,
+         output_file = 'importer.log',
+         shell_name = 'dataportal_importer.sh',
+         conda_path = args.conda_env,
+         command = command
+     )
+
+ def download_dataportal_slurm():
+     """
+     Create a SLURM script for downloading tomograms from the Dataportal
+     """
+     parser_description = "Create a SLURM script for downloading tomograms from the Dataportal"
+     args = cli_dataportal_parser(parser_description, add_slurm=True)
+     create_download_dataportal_script(args)
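As a rough illustration of the flags create_train_script consumes, it can be driven with a plain argparse.Namespace. Every field below mirrors an attribute referenced in the function above; all values (paths, target info, GPU constraint) are placeholders rather than documented defaults.

# Hypothetical sketch: calling create_train_script directly instead of via the CLI parser.
from argparse import Namespace
from octopi.entry_points.create_slurm_submission import create_train_script

args = Namespace(
    config=["session_A,configA.json"],          # placeholder session,config pair
    model_save_path="results/",
    target_info="targets,octopi,1",             # assumed name,user_id,session_id format
    voxel_size=10, tomo_algorithm="wbp", Nclass=3,
    best_metric="avg_f1", num_epochs=100, val_interval=10,
    tomo_batch_size=15, num_tomo_crops=16,
    model_config=None, model_weights=None,      # fall back to the explicit architecture flags
    tversky_alpha=0.5, channels="32,64,128,128", strides="2,2,1", dim_in=96, res_units=2,
    job_name="train-octopi", conda_env="/path/to/conda_env", gpu_constraint="h100",
)
create_train_script(args)  # writes train_octopi.sh containing the assembled `octopi train` command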