octopi-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/datasets/io.py ADDED
@@ -0,0 +1,200 @@
+ """
+ Data loading, processing, and dataset operations for the datasets module.
+ """
+
+ from monai.data import DataLoader, CacheDataset
+ from monai.transforms import (
+     Compose,
+     NormalizeIntensityd,
+     EnsureChannelFirstd,
+ )
+ from sklearn.model_selection import train_test_split
+ from collections import defaultdict
+ from copick_utils.io import readers
+ import copick, torch, os, random
+ from typing import List
+ from tqdm import tqdm
+
+
+ def load_training_data(root,
+                        runIDs: List[str],
+                        voxel_spacing: float,
+                        tomo_algorithm: str,
+                        segmentation_name: str,
+                        segmentation_session_id: str = None,
+                        segmentation_user_id: str = None,
+                        progress_update: bool = True):
+     """
+     Load training data (tomogram / segmentation pairs) from CoPick runs.
+     """
+     data_dicts = []
+     # Use tqdm for progress tracking only if progress_update is True
+     iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
+     for runID in iterable:
+         run = root.get_run(str(runID))
+         tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+         segmentation = readers.segmentation(run,
+                                             voxel_spacing,
+                                             segmentation_name,
+                                             segmentation_session_id,
+                                             segmentation_user_id)
+         data_dicts.append({"image": tomogram, "label": segmentation})
+
+     return data_dicts
+
+
+ def load_predict_data(root,
+                       runIDs: List[str],
+                       voxel_spacing: float,
+                       tomo_algorithm: str):
+     """
+     Load prediction data (tomograms only) from CoPick runs.
+     """
+     data_dicts = []
+     for runID in tqdm(runIDs, desc="Loading Prediction Data"):
+         run = root.get_run(str(runID))
+         tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
+         data_dicts.append({"image": tomogram})
+
+     return data_dicts
+
+
+ def create_predict_dataloader(
+     root,
+     voxel_spacing: float,
+     tomo_algorithm: str,
+     runIDs: List[str] = None,
+ ):
+     """
+     Create a dataloader for prediction data.
+     """
+     # Define pre-transforms
+     pre_transforms = Compose(
+         [EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+          NormalizeIntensityd(keys=["image"]),
+         ])
+
+     # Default to all runs in the project if no run IDs are provided
+     if runIDs is None:
+         runIDs = [run.name for run in root.runs]
+     test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)
+
+     bs = min(len(test_files), os.cpu_count() or 4)
+     test_ds = CacheDataset(data=test_files, transform=pre_transforms)
+     test_loader = DataLoader(test_ds,
+                              batch_size=bs,
+                              shuffle=False,
+                              num_workers=bs,
+                              pin_memory=torch.cuda.is_available())
+     return test_loader, test_ds
+
+
+ def adjust_to_multiple(value, multiple=16):
+     """
+     Round a value down to the nearest multiple of the specified number.
+     """
+     return int((value // multiple) * multiple)
+
+
+ def get_input_dimensions(dataset, crop_size: int):
+     """
+     Determine the input patch dimensions for the dataset, shrinking the first
+     axis when the requested crop size exceeds the tomogram depth.
+     """
+     nx = dataset[0]['image'].shape[1]
+     if crop_size > nx:
+         first_dim = adjust_to_multiple(nx / 2)
+         return first_dim, crop_size, crop_size
+     else:
+         return crop_size, crop_size, crop_size
+
+
+ def get_num_classes(copick_config_path: str):
+     """
+     Get the number of classes (pickable objects plus background) from a
+     CoPick configuration.
+     """
+     root = copick.from_file(copick_config_path)
+     return len(root.pickable_objects) + 1
+
+
+ def split_multiclass_dataset(runIDs,
+                              train_ratio: float = 0.7,
+                              val_ratio: float = 0.15,
+                              test_ratio: float = 0.15,
+                              return_test_dataset: bool = True,
+                              random_state: int = 42):
+     """
+     Split a dataset into training, validation, and testing subsets. If the
+     dataset has categories (as tuples), splits are balanced across all
+     categories. If the dataset is a 1D list, it is split without categorization.
+
+     Parameters:
+     - runIDs: Items to split; either a 1D list or a list of (category, value) tuples.
+     - train_ratio: Proportion of the dataset for training.
+     - val_ratio: Proportion of the dataset for validation.
+     - test_ratio: Proportion of the dataset for testing.
+     - return_test_dataset: Whether to return the test dataset.
+     - random_state: Random state for reproducibility.
+
+     Returns:
+     - trainRunIDs: Training subset.
+     - valRunIDs: Validation subset.
+     - testRunIDs: Testing subset (empty list if return_test_dataset is False).
+     """
+
+     # Ensure the ratios add up to 1, with a tolerance for floating-point error
+     assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must add up to 1.0"
+
+     # Check if the dataset has categories
+     if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
+         # Group by category
+         grouped = defaultdict(list)
+         for item in runIDs:
+             grouped[item[0]].append(item)
+
+         # Split each category
+         trainRunIDs, valRunIDs, testRunIDs = [], [], []
+         for category, items in grouped.items():
+             # Shuffle for randomness (note: this uses the global random state)
+             random.shuffle(items)
+             # Split into train and remaining
+             train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
+             trainRunIDs.extend(train_items)
+
+             if return_test_dataset:
+                 # Split remaining into validation and test
+                 val_items, test_items = train_test_split(
+                     remaining,
+                     test_size=(test_ratio / (val_ratio + test_ratio)),
+                     random_state=random_state,
+                 )
+                 valRunIDs.extend(val_items)
+                 testRunIDs.extend(test_items)
+             else:
+                 valRunIDs.extend(remaining)
+                 testRunIDs = []
+     else:
+         # If no categories, split as a 1D list
+         trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
+         if return_test_dataset:
+             valRunIDs, testRunIDs = train_test_split(
+                 remaining,
+                 test_size=(test_ratio / (val_ratio + test_ratio)),
+                 random_state=random_state,
+             )
+         else:
+             valRunIDs = remaining
+             testRunIDs = []
+
+     return trainRunIDs, valRunIDs, testRunIDs
+
+
+ def load_copick_config(path: str):
+     """
+     Load a CoPick configuration from file.
+     """
+     if os.path.isfile(path):
+         root = copick.from_file(path)
+     else:
+         raise FileNotFoundError(f"Copick Config Path does not exist: {path}")
+
+     return root
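Taken together, these helpers cover the typical flow: load a config, split the runs, and materialize image/label dictionaries. A minimal usage sketch, assuming a hypothetical CoPick project config at config.json whose segmentations are saved under the placeholder name "targets":

from octopi.datasets import io

# Hypothetical project; config.json and "targets" are placeholders
root = io.load_copick_config("config.json")
run_ids = [run.name for run in root.runs]

# Default 70/15/15 split into train / validation / test run IDs
train_ids, val_ids, test_ids = io.split_multiclass_dataset(run_ids)

# Load (image, label) dictionaries for the training runs
train_files = io.load_training_data(root, train_ids,
                                    voxel_spacing=10.0,
                                    tomo_algorithm="wbp",
                                    segmentation_name="targets")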
octopi/datasets/mixup.py ADDED
@@ -0,0 +1,49 @@
+ from monai.transforms import Transform
+ from torch.distributions import Beta
+ from torch import nn
+ import numpy as np
+ import torch
+
+ class MixupTransformd(Transform):
+     """
+     A dictionary-based wrapper for Mixup augmentation that applies to batches.
+     This needs to be applied after batching, typically directly in the training loop.
+     """
+     def __init__(self, keys=("image", "label"), mix_beta=0.2, mixadd=False, prob=0.5):
+         self.keys = keys
+         self.mixup = Mixup(mix_beta=mix_beta, mixadd=mixadd)
+         self.prob = prob
+
+     def __call__(self, data):
+         d = dict(data)
+         if np.random.random() < self.prob:  # Apply with probability `prob`
+             d[self.keys[0]], d[self.keys[1]] = self.mixup(d[self.keys[0]], d[self.keys[1]])
+         return d
+
+ class Mixup(nn.Module):
+     def __init__(self, mix_beta, mixadd=False):
+         super(Mixup, self).__init__()
+         self.beta_distribution = Beta(mix_beta, mix_beta)
+         self.mixadd = mixadd
+
+     def forward(self, X, Y, Z=None):
+         # Draw one mixing coefficient per sample and a random pairing permutation
+         bs = X.shape[0]
+         perm = torch.randperm(bs)
+         coeffs = self.beta_distribution.rsample(torch.Size((bs,))).to(X.device)
+         X_coeffs = coeffs.view((-1,) + (1,) * (X.ndim - 1))
+         Y_coeffs = coeffs.view((-1,) + (1,) * (Y.ndim - 1))
+
+         # Blend each sample with its permuted partner
+         X = X_coeffs * X + (1 - X_coeffs) * X[perm]
+
+         if self.mixadd:
+             Y = (Y + Y[perm]).clip(0, 1)
+         else:
+             Y = Y_coeffs * Y + (1 - Y_coeffs) * Y[perm]
+
+         if Z is not None:
+             return X, Y, Z
+
+         return X, Y
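Because Mixup blends whole batches, it is invoked inside the training loop after collation rather than as a per-sample MONAI transform. A minimal sketch with synthetic tensors (the shapes and class count below are illustrative only):

import torch
from octopi.datasets.mixup import Mixup

mixup = Mixup(mix_beta=0.2)

# Synthetic batch: 4 single-channel 64^3 crops with 3-class one-hot labels
images = torch.randn(4, 1, 64, 64, 64)
labels = torch.nn.functional.one_hot(
    torch.randint(0, 3, (4, 64, 64, 64)), num_classes=3
).permute(0, 4, 1, 2, 3).float()

mixed_images, mixed_labels = mixup(images, labels)
print(mixed_images.shape, mixed_labels.shape)  # [4, 1, 64, 64, 64], [4, 3, 64, 64, 64]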
octopi/datasets/multi_config_generator.py ADDED
@@ -0,0 +1,252 @@
+ from octopi.datasets import dataset, augment, cached_datset
+ from octopi.datasets.generators import TrainLoaderManager
+ from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
+ from octopi.datasets import io
+ import multiprocess as mp
+ from typing import List
+ from tqdm import tqdm
+ import torch, gc
+
+ class MultiConfigTrainLoaderManager(TrainLoaderManager):
+
+     def __init__(self,
+                  configs: dict,  # Dictionary of session names and config paths
+                  target_name: str,
+                  target_session_id: str = None,
+                  target_user_id: str = None,
+                  voxel_size: float = 10,
+                  tomo_algorithm: List[str] = ['wbp'],
+                  tomo_batch_size: int = 15
+                  ):
+         """
+         Initialize MultiConfigTrainLoaderManager with multiple configs.
+
+         Args:
+             configs (dict): Mapping of session names to config file paths.
+             Other arguments are inherited from TrainLoaderManager.
+         """
+
+         # Initialize shared attributes manually (skip super().__init__ to avoid invalid config handling)
+         self.config = configs
+         self.roots = {name: io.load_copick_config(path) for name, path in configs.items()}
+
+         # Target and algorithm parameters
+         self.target_name = target_name
+         self.target_session_id = target_session_id
+         self.target_user_id = target_user_id
+         self.voxel_size = voxel_size
+         self.tomo_algorithm = tomo_algorithm
+
+         # Data management parameters
+         self.tomo_batch_size = tomo_batch_size
+         self.reload_training_dataset = True
+         self.reload_validation_dataset = True
+         self.val_loader = None
+         self.train_loader = None
+
+         # Initialize Run IDs placeholder
+         self.myRunIDs = {}
+
+         # Initialize the input dimensions
+         self.nx = None
+         self.ny = None
+         self.nz = None
+
+     def get_available_runIDs(self):
+         """
+         Identify and return a combined list of run IDs with available segmentations
+         across all configured CoPick projects.
+
+         Returns:
+             List of tuples: [(session_name, run_name), ...]
+         """
+         available_runIDs = []
+         for name, root in self.roots.items():
+             runIDs = [run.name for run in root.runs]
+             for run_name in runIDs:
+                 run = root.get_run(run_name)
+                 seg = run.get_segmentations(
+                     name=self.target_name,
+                     session_id=self.target_session_id,
+                     user_id=self.target_user_id,
+                     voxel_size=float(self.voxel_size)
+                 )
+                 if len(seg) > 0:
+                     available_runIDs.append((name, run.name))  # Include session name for disambiguation
+
+         # If no segmentations are found, inform the user
+         if len(available_runIDs) == 0:
+             raise RuntimeError(
+                 f"No segmentations found for the target query:\n"
+                 f"TargetName: {self.target_name}, UserID: {self.target_user_id}, "
+                 f"SessionID: {self.target_session_id}\n"
+                 f"Please check the target name, user ID, and session ID."
+             )
+
+         return available_runIDs
+
+     def get_data_splits(self,
+                         trainRunIDs: str = None,
+                         validateRunIDs: str = None,
+                         train_ratio: float = 0.8,
+                         val_ratio: float = 0.1,
+                         test_ratio: float = 0.1,
+                         create_test_dataset: bool = True):
+         """
+         Override to handle run IDs as (session_name, run_name) tuples.
+         """
+         # Run IDs are always discovered across all projects; any user-supplied
+         # trainRunIDs / validateRunIDs are ignored by this override.
+         runIDs = self.get_available_runIDs()
+         return super().get_data_splits(trainRunIDs=runIDs,
+                                        train_ratio=train_ratio,
+                                        val_ratio=val_ratio,
+                                        test_ratio=test_ratio,
+                                        create_test_dataset=create_test_dataset)
+
+     def _initialize_train_iterators(self):
+         """
+         Initialize the training data iterators with multi-config support.
+         """
+         self.padded_train_list = self._get_padded_list(self.myRunIDs['train'], self.train_batch_size)
+         self.train_data_iter = iter(self._get_data_batches(self.padded_train_list, self.train_batch_size))
+
+     def _initialize_val_iterators(self):
+         """
+         Initialize the validation data iterators with multi-config support.
+         """
+         self.padded_val_list = self._get_padded_list(self.myRunIDs['validate'], self.val_batch_size)
+         self.val_data_iter = iter(self._get_data_batches(self.padded_val_list, self.val_batch_size))
+
+     def _load_data(self, runIDs):
+         """
+         Load data from multiple CoPick projects for given run IDs.
+
+         Args:
+             runIDs (list): List of (session_name, run_name) tuples.
+
+         Returns:
+             List: Combined dataset for the specified run IDs.
+         """
+         data = []
+         for session_name, run_name in tqdm(runIDs):
+             root = self.roots[session_name]
+             data.extend(io.load_training_data(
+                 root, [run_name], self.voxel_size, self.tomo_algorithm,
+                 self.target_name, self.target_session_id, self.target_user_id,
+                 progress_update=False))
+         self._check_max_label_value(data)
+         return data
+
+     def create_train_dataloaders(self, *args, **kwargs):
+         """
+         Override data loading to fetch from multiple projects.
+         """
+         my_crop_size = kwargs.get("crop_size", 96)
+         my_num_samples = kwargs.get("num_samples", 128)
+
+         # If reloads are disabled and loaders already exist, reuse them
+         if self.reload_frequency < 0 and (self.train_loader is not None) and (self.val_loader is not None):
+             return self.train_loader, self.val_loader
+
+         # Estimate the number of workers from the available CPU count (capped at 4)
+         n_procs = min(mp.cpu_count(), 4)
+
+         if self.train_loader is None:
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = self._load_data(trainRunIDs)
+
+             # Create the cached dataset with non-random transforms
+             # (a cached_datset.MultiConfigCacheDataset could be used here instead)
+             train_ds = SmartCacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=0.5)
+
+             # Delete the training files to free memory
+             train_files = None
+             gc.collect()
+
+             # Read (nx, ny, nz) and scale the crop size so it isn't larger than nx
+             if self.nx is None:
+                 (self.nx, self.ny, self.nz) = train_ds[0]['image'].shape[1:]
+             self.input_dim = io.get_input_dimensions(train_ds, my_crop_size)
+
+             # Wrap the cached dataset to apply random transforms during iteration
+             self.dynamic_train_dataset = dataset.DynamicDataset(
+                 data=train_ds,
+                 transform=augment.get_random_transforms(self.input_dim, my_num_samples, self.Nclasses)
+             )
+
+             self.train_loader = DataLoader(
+                 self.dynamic_train_dataset,
+                 batch_size=1,
+                 shuffle=True,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+             )
+
+         else:
+             # Fetch the next batch of run IDs
+             trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+             train_files = self._load_data(trainRunIDs)
+             train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
+             self.dynamic_train_dataset.update_data(train_ds)
+
+         # The validation dataset only needs to be reloaded if the total number
+         # of runs is larger than the tomo batch size
+         if self.val_loader is None:
+
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = self._load_data(validateRunIDs)
+
+             # Create the validation dataset with non-random transforms
+             val_ds = SmartCacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+
+             # Delete the validation files to free memory
+             val_files = None
+             gc.collect()
+
+             # Wrap the cached dataset; validation applies only the deterministic transforms
+             self.dynamic_validation_dataset = dataset.DynamicDataset(data=val_ds)
+
+             # Create validation DataLoader
+             self.val_loader = DataLoader(
+                 self.dynamic_validation_dataset,
+                 batch_size=1,
+                 num_workers=n_procs,
+                 pin_memory=torch.cuda.is_available(),
+                 shuffle=False,  # Ensure the data order remains consistent
+             )
+         else:
+             validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+             val_files = self._load_data(validateRunIDs)
+
+             val_ds = CacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
+             self.dynamic_validation_dataset.update_data(val_ds)
+
+         return self.train_loader, self.val_loader
+
+     def tmp_return_datasets(self):
+         trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
+         train_files = self._load_data(trainRunIDs)
+
+         validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
+         val_files = self._load_data(validateRunIDs)
+
+         return train_files, val_files
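A sketch of driving the manager end to end, assuming the inherited TrainLoaderManager populates the split bookkeeping (myRunIDs, batch sizes, reload_frequency) from get_data_splits; the session names, config paths, and "targets" name below are placeholders:

from octopi.datasets.multi_config_generator import MultiConfigTrainLoaderManager

manager = MultiConfigTrainLoaderManager(
    configs={"sessionA": "projectA/config.json",
             "sessionB": "projectB/config.json"},
    target_name="targets",
    voxel_size=10,
    tomo_algorithm=["wbp"],
)

# Balanced (session, run) splits across both projects
manager.get_data_splits(train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)
train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=64)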
octopi/entry_points/__init__.py ADDED (empty file, no content to show)
octopi/entry_points/common.py ADDED
@@ -0,0 +1,119 @@
+ from octopi.utils import parsers
+ import rich_click as click
+ from typing import List, Tuple
+
+ def model_parameters(octopi: bool = False):
+     """Decorator for adding model parameters"""
+     def decorator(f):
+         f = click.option("-dim", "--dim-in", type=int, default=96,
+                          help="Input dimension for the UNet model")(f)
+         f = click.option("--res-units", type=int, default=1,
+                          help="Number of residual units in the UNet")(f)
+         f = click.option("-s", "--strides", type=str, default='2,2,1',
+                          callback=lambda ctx, param, value: parsers.parse_int_list(value) if value else value,
+                          help="List of stride sizes")(f)
+         f = click.option("-ch", "--channels", type=str, default='32,64,96,96',
+                          callback=lambda ctx, param, value: parsers.parse_int_list(value) if value else value,
+                          help="List of channel sizes")(f)
+         return f
+     return decorator
+
+ def inference_model_parameters():
+     """Decorator for adding inference model parameters"""
+     def decorator(f):
+         f = click.option("-mw", "--model-weights", type=click.Path(exists=True), required=True,
+                          help="Path to the model weights file")(f)
+         f = click.option("-mc", "--model-config", type=click.Path(exists=True), required=True,
+                          help="Path to the model configuration file")(f)
+         return f
+     return decorator
+
+ def train_parameters(octopi: bool = False):
+     """Decorator for adding training parameters"""
+     def decorator(f):
+         if octopi:
+             f = click.option("-nt", "--num-trials", type=int, default=10,
+                              help="Number of trials for architecture search")(f)
+         else:
+             f = click.option("-o", "--model-save-path", type=click.Path(), default='results',
+                              help="Path to the model save directory")(f)
+             f = click.option("--tversky-alpha", type=float, default=0.3,
+                              help="Alpha parameter for the Tversky loss")(f)
+             f = click.option("-lr", "--lr", type=float, default=1e-3,
+                              help="Learning rate for the optimizer")(f)
+             f = click.option('-ncrops', "--num-tomo-crops", type=int, default=16,
+                              help="Number of tomogram crops to use per patch")(f)
+
+         f = click.option("--best-metric", type=str, default='avg_f1',
+                          help="Metric to monitor when selecting the best model. To track fBetaN, use fBetaN with N as the beta value.")(f)
+         f = click.option('-ntomos', "--tomo-batch-size", type=int, default=15,
+                          help="Number of tomograms to load per epoch for training")(f)
+         f = click.option("--val-interval", type=int, default=10,
+                          help="Interval for validation metric calculations")(f)
+         f = click.option('-nepochs', "--num-epochs", type=int, default=1000,
+                          help="Number of training epochs")(f)
+         return f
+     return decorator
+
+ def config_parameters(single_config: bool):
+     """Decorator for adding config parameters"""
+     def decorator(f):
+         f = click.option("-vs", "--voxel-size", type=float, default=10,
+                          help="Voxel size of the tomograms used")(f)
+         if single_config:
+             f = click.option("-c", "--config", type=click.Path(exists=True), required=True,
+                              help="Path to the configuration file")(f)
+         else:
+             f = click.option("-c", "--config", type=str, multiple=True, required=True,
+                              help="Specify a single configuration path (/path/to/config.json) or multiple entries in the format session_name,/path/to/config.json. Use multiple --config entries for multiple sessions.")(f)
+         return f
+     return decorator
+
+ def inference_parameters():
+     """Decorator for adding inference parameters"""
+     def decorator(f):
+         f = click.option('-runs', "--run-ids", type=str, default=None,
+                          callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
+                          help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")(f)
+         f = click.option('-ntomos', "--tomo-batch-size", type=int, default=25,
+                          help="Batch size for tomogram processing")(f)
+         f = click.option('-seginfo', "--seg-info", type=str, default='predict,octopi,1',
+                          callback=lambda ctx, param, value: parsers.parse_target(value) if value else value,
+                          help='Query to save segmentation predictions under (e.g., "name" or "name,user_id,session_id"). Default user ID is octopi and session ID is 1.')(f)
+         f = click.option('-alg', "--tomo-alg", type=str, default='wbp',
+                          help="Tomogram algorithm used to produce segmentation prediction masks")(f)
+         return f
+     return decorator
+
+ def localize_parameters():
+     """Decorator for adding localization parameters"""
+     def decorator(f):
+         f = click.option('-seginfo', "--seg-info", type=str, required=True,
+                          help="Segmentation info")(f)
+         f = click.option("--pick-objects", type=str, required=True,
+                          help="Pick objects")(f)
+         f = click.option("--pick-session-id", type=str, default="1",
+                          help="Pick session ID")(f)
+         f = click.option('-m', "--method", type=str, default='watershed',
+                          help="Localization method")(f)
+         f = click.option('-vs', "--voxel-size", type=int, default=10,
+                          help="Voxel size")(f)
+         return f
+     return decorator
+
+ def slurm_parameters(base_job_name: str, gpus: int = 1):
+     """Decorator for adding SLURM parameters"""
+     def decorator(f):
+         if gpus > 1:
+             f = click.option("--num-gpus", type=int, default=1,
+                              help="Number of GPUs to use")(f)
+         if gpus > 0:
+             f = click.option("--gpu-constraint", type=click.Choice(['a6000', 'a100', 'h100', 'h200'], case_sensitive=False),
+                              default='h100',
+                              help="GPU constraint")(f)
+         f = click.option("--job-name", type=str, default=base_job_name,
+                          help="Job name for the SLURM job")(f)
+         f = click.option("--conda-env", type=click.Path(), default='/hpc/projects/group.czii/conda_environments/pyUNET/',
+                          help="Path to the Conda environment")(f)
+         return f
+     return decorator
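These decorators are meant to be stacked onto click commands so every entry point shares one option vocabulary. A sketch of the wiring, using the decorators defined above; the command name and body are hypothetical:

import rich_click as click
from octopi.entry_points import common

@click.command(name="predict")
@common.config_parameters(single_config=True)
@common.inference_model_parameters()
@common.inference_parameters()
def predict(config, voxel_size, model_weights, model_config,
            run_ids, tomo_batch_size, seg_info, tomo_alg):
    """Hypothetical command that wires the shared options together."""
    click.echo(f"Running inference on {config} at voxel size {voxel_size}")

if __name__ == "__main__":
    predict()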