octopi-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/datasets/io.py
ADDED
@@ -0,0 +1,200 @@
"""
Data loading, processing, and dataset operations for the datasets module.
"""

from monai.data import DataLoader, CacheDataset
from monai.transforms import (
    Compose,
    NormalizeIntensityd,
    EnsureChannelFirstd,
)
from sklearn.model_selection import train_test_split
from collections import defaultdict
from copick_utils.io import readers
import copick, torch, os, random
from typing import List
from tqdm import tqdm


def load_training_data(root,
                       runIDs: List[str],
                       voxel_spacing: float,
                       tomo_algorithm: str,
                       segmenation_name: str,
                       segmentation_session_id: str = None,
                       segmentation_user_id: str = None,
                       progress_update: bool = True):
    """
    Load training data from CoPick runs.
    """
    data_dicts = []
    # Use tqdm for progress tracking only if progress_update is True
    iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
    for runID in iterable:
        run = root.get_run(str(runID))
        tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
        segmentation = readers.segmentation(run,
                                            voxel_spacing,
                                            segmenation_name,
                                            segmentation_session_id,
                                            segmentation_user_id)
        data_dicts.append({"image": tomogram, "label": segmentation})

    return data_dicts
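
For context, a minimal usage sketch for the loader above (illustrative only, not part of the wheel; the config path, run names, and target name are hypothetical):

import copick
from octopi.datasets import io

root = copick.from_file("copick_config.json")            # hypothetical CoPick project
train_dicts = io.load_training_data(root, ["TS_001", "TS_002"],
                                    voxel_spacing=10.0,
                                    tomo_algorithm="wbp",
                                    segmenation_name="targets")  # parameter spelled as in the package
# Each entry pairs a tomogram with its mask: {"image": <array>, "label": <array>}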


def load_predict_data(root,
                      runIDs: List[str],
                      voxel_spacing: float,
                      tomo_algorithm: str):
    """
    Load prediction data from CoPick runs.
    """
    data_dicts = []
    for runID in tqdm(runIDs):
        run = root.get_run(str(runID))
        tomogram = readers.tomogram(run, voxel_spacing, tomo_algorithm)
        data_dicts.append({"image": tomogram})

    return data_dicts


def create_predict_dataloader(
    root,
    voxel_spacing: float,
    tomo_algorithm: str,
    runIDs: str = None,
):
    """
    Create a dataloader for prediction data.
    """
    # Define pre-transforms
    pre_transforms = Compose(
        [EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
         NormalizeIntensityd(keys=["image"]),
        ])

    # Default to all available runs if none are specified
    if runIDs is None:
        runIDs = [run.name for run in root.runs]
    test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)

    bs = min(len(test_files), os.cpu_count() or 4)
    test_ds = CacheDataset(data=test_files, transform=pre_transforms)
    test_loader = DataLoader(test_ds,
                             batch_size=bs,
                             shuffle=False,
                             num_workers=bs,
                             pin_memory=torch.cuda.is_available())
    return test_loader, test_ds
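
For context, a sketch of driving inference with this loader (illustrative; the config path is hypothetical, and the batch shape assumes single-channel volumes of matching size):

root = io.load_copick_config("copick_config.json")
loader, ds = io.create_predict_dataloader(root, voxel_spacing=10.0, tomo_algorithm="wbp")
for batch in loader:
    vols = batch["image"]   # (B, 1, Nz, Ny, Nx) after EnsureChannelFirstd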


def adjust_to_multiple(value, multiple=16):
    """
    Adjust a value to be a multiple of the specified number.
    """
    return int((value // multiple) * multiple)


def get_input_dimensions(dataset, crop_size: int):
    """
    Get input dimensions for the dataset.
    """
    nx = dataset[0]['image'].shape[1]
    if crop_size > nx:
        first_dim = adjust_to_multiple(nx / 2)
        return first_dim, crop_size, crop_size
    else:
        return crop_size, crop_size, crop_size
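
A quick worked example of the rounding above: adjust_to_multiple(60) evaluates to (60 // 16) * 16 = 48. So for a tomogram with nx = 92 and crop_size = 96, get_input_dimensions returns (adjust_to_multiple(46), 96, 96) = (32, 96, 96); with nx = 184 it returns (96, 96, 96).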


def get_num_classes(copick_config_path: str):
    """
    Get the number of classes from a CoPick configuration.
    """
    root = copick.from_file(copick_config_path)
    return len(root.pickable_objects) + 1  # pickable objects plus background


def split_multiclass_dataset(runIDs,
                             train_ratio: float = 0.7,
                             val_ratio: float = 0.15,
                             test_ratio: float = 0.15,
                             return_test_dataset: bool = True,
                             random_state: int = 42):
    """
    Splits a given dataset into three subsets: training, validation, and testing. If the dataset
    has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
    list, it is split without categorization.

    Parameters:
    - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
    - train_ratio: Proportion of the dataset for training.
    - val_ratio: Proportion of the dataset for validation.
    - test_ratio: Proportion of the dataset for testing.
    - return_test_dataset: Whether to return the test dataset.
    - random_state: Random state for reproducibility.

    Returns:
    - trainRunIDs: Training subset.
    - valRunIDs: Validation subset.
    - testRunIDs: Testing subset (an empty list if return_test_dataset is False).
    """

    # Ensure the ratios add up to 1
    assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"

    # Check if the dataset has categories
    if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
        # Group by category
        grouped = defaultdict(list)
        for item in runIDs:
            grouped[item[0]].append(item)

        # Split each category
        trainRunIDs, valRunIDs, testRunIDs = [], [], []
        for category, items in grouped.items():
            # Shuffle for randomness
            random.shuffle(items)
            # Split into train and remaining
            train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
            trainRunIDs.extend(train_items)

            if return_test_dataset:
                # Split remaining into validation and test
                val_items, test_items = train_test_split(
                    remaining,
                    test_size=(test_ratio / (val_ratio + test_ratio)),
                    random_state=random_state,
                )
                valRunIDs.extend(val_items)
                testRunIDs.extend(test_items)
            else:
                valRunIDs.extend(remaining)
                testRunIDs = []
    else:
        # If no categories, split as a 1D list
        trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
        if return_test_dataset:
            valRunIDs, testRunIDs = train_test_split(
                remaining,
                test_size=(test_ratio / (val_ratio + test_ratio)),
                random_state=random_state,
            )
        else:
            valRunIDs = remaining
            testRunIDs = []

    return trainRunIDs, valRunIDs, testRunIDs
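
For context, a sketch of both split modes (illustrative run names):

runs = [f"TS_{i:03d}" for i in range(20)]
train, val, test = io.split_multiclass_dataset(runs)            # 14 / 3 / 3 with default ratios

paired = [("exp1", r) for r in runs] + [("exp2", r) for r in runs]
train, val, test = io.split_multiclass_dataset(paired)          # balanced per category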


def load_copick_config(path: str):
    """
    Load a CoPick configuration from file.
    """
    if os.path.isfile(path):
        root = copick.from_file(path)
    else:
        raise FileNotFoundError(f"Copick Config Path does not exist: {path}")

    return root
octopi/datasets/mixup.py
ADDED
@@ -0,0 +1,49 @@
from monai.transforms import Transform
from torch.distributions import Beta
from torch import nn
import numpy as np
import torch

class MixupTransformd(Transform):
    """
    A dictionary-based wrapper for Mixup augmentation that applies to batches.
    This needs to be applied after batching, typically directly in the training loop.
    """
    def __init__(self, keys=("image", "label"), mix_beta=0.2, mixadd=False, prob=0.5):
        self.keys = keys
        self.mixup = Mixup(mix_beta=mix_beta, mixadd=mixadd)
        self.prob = prob

    def __call__(self, data):
        d = dict(data)
        if np.random.random() < self.prob:  # Apply with probability
            d[self.keys[0]], d[self.keys[1]] = self.mixup(d[self.keys[0]], d[self.keys[1]])
        return d

class Mixup(nn.Module):
    def __init__(self, mix_beta, mixadd=False):
        super(Mixup, self).__init__()
        self.beta_distribution = Beta(mix_beta, mix_beta)
        self.mixadd = mixadd

    def forward(self, X, Y, Z=None):
        bs = X.shape[0]
        perm = torch.randperm(bs)
        coeffs = self.beta_distribution.rsample(torch.Size((bs,))).to(X.device)
        X_coeffs = coeffs.view((-1,) + (1,) * (X.ndim - 1))
        Y_coeffs = coeffs.view((-1,) + (1,) * (Y.ndim - 1))

        X = X_coeffs * X + (1 - X_coeffs) * X[perm]

        if self.mixadd:
            Y = (Y + Y[perm]).clip(0, 1)
        else:
            Y = Y_coeffs * Y + (1 - Y_coeffs) * Y[perm]

        if Z is not None:  # an explicit check; `if Z:` would raise on multi-element tensors
            return X, Y, Z

        return X, Y
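
Since mixup operates on whole batches, a minimal training-loop sketch (illustrative; model, loss_fn, optimizer, and train_loader are assumed to exist, and labels are assumed to be one-hot or soft masks so that linear blending is meaningful):

from octopi.datasets.mixup import MixupTransformd

mixup = MixupTransformd(keys=("image", "label"), mix_beta=0.2, prob=0.5)

for batch in train_loader:            # batch["image"]: (B, C, Nz, Ny, Nx)
    batch = mixup(batch)              # blends random pairs of volumes within the batch
    optimizer.zero_grad()
    loss = loss_fn(model(batch["image"]), batch["label"])
    loss.backward()
    optimizer.step()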
octopi/datasets/multi_config_generator.py
ADDED
@@ -0,0 +1,252 @@
from octopi.datasets import dataset, augment, cached_datset
from octopi.datasets.generators import TrainLoaderManager
from monai.data import DataLoader, SmartCacheDataset, CacheDataset, Dataset
from octopi.datasets import io
import multiprocess as mp
from typing import List
from tqdm import tqdm
import torch, gc

class MultiConfigTrainLoaderManager(TrainLoaderManager):

    def __init__(self,
                 configs: dict,  # Dictionary of session names and config paths
                 target_name: str,
                 target_session_id: str = None,
                 target_user_id: str = None,
                 voxel_size: float = 10,
                 tomo_algorithm: List[str] = ['wbp'],
                 tomo_batch_size: int = 15
                 ):
        """
        Initialize MultiConfigTrainLoaderManager with multiple configs.

        Args:
            configs (dict): Mapping of session names to config file paths.
            Other arguments are inherited from TrainLoaderManager.
        """

        # Initialize shared attributes manually (skip super().__init__ to avoid invalid config handling)
        self.config = configs
        self.roots = {name: io.load_copick_config(path) for name, path in configs.items()}

        # Target and algorithm parameters
        self.target_name = target_name
        self.target_session_id = target_session_id
        self.target_user_id = target_user_id
        self.voxel_size = voxel_size
        self.tomo_algorithm = tomo_algorithm

        # Data management parameters
        self.tomo_batch_size = tomo_batch_size
        self.reload_training_dataset = True
        self.reload_validation_dataset = True
        self.val_loader = None
        self.train_loader = None

        # Initialize Run IDs placeholder
        self.myRunIDs = {}

        # Initialize the input dimensions
        self.nx = None
        self.ny = None
        self.nz = None

    def get_available_runIDs(self):
        """
        Identify and return a combined list of run IDs with available segmentations
        across all configured CoPick projects.

        Returns:
            List of tuples: [(session_name, run_name), ...]
        """
        available_runIDs = []
        for name, root in self.roots.items():
            runIDs = [run.name for run in root.runs]
            for run in runIDs:
                run = root.get_run(run)
                seg = run.get_segmentations(
                    name=self.target_name,
                    session_id=self.target_session_id,
                    user_id=self.target_user_id,
                    voxel_size=float(self.voxel_size)
                )
                if len(seg) > 0:
                    available_runIDs.append((name, run.name))  # Include session name for disambiguation

        # If no segmentations are found, inform the user
        if len(available_runIDs) == 0:
            print(
                f"[Error] No segmentations found for the target query:\n"
                f"TargetName: {self.target_name}, UserID: {self.target_user_id}, "
                f"SessionID: {self.target_session_id}\n"
                f"Please check the target name, user ID, and session ID.\n"
            )
            exit()

        return available_runIDs

    def get_data_splits(self,
                        trainRunIDs: str = None,
                        validateRunIDs: str = None,
                        train_ratio: float = 0.8,
                        val_ratio: float = 0.1,
                        test_ratio: float = 0.1,
                        create_test_dataset: bool = True):
        """
        Override to handle run IDs as (session_name, run_name) tuples.
        """
        # Use the get_available_runIDs method to handle multiple projects
        runIDs = self.get_available_runIDs()
        return super().get_data_splits(trainRunIDs=runIDs,
                                       train_ratio=train_ratio,
                                       val_ratio=val_ratio,
                                       test_ratio=test_ratio,
                                       create_test_dataset=create_test_dataset)

    def _initialize_train_iterators(self):
        """
        Initialize the training data iterators with multi-config support.
        """
        self.padded_train_list = self._get_padded_list(self.myRunIDs['train'], self.train_batch_size)
        self.train_data_iter = iter(self._get_data_batches(self.padded_train_list, self.train_batch_size))

    def _initialize_val_iterators(self):
        """
        Initialize the validation data iterators with multi-config support.
        """
        self.padded_val_list = self._get_padded_list(self.myRunIDs['validate'], self.val_batch_size)
        self.val_data_iter = iter(self._get_data_batches(self.padded_val_list, self.val_batch_size))

    def _load_data(self, runIDs):
        """
        Load data from multiple CoPick projects for given run IDs.

        Args:
            runIDs (list): List of (session_name, run_name) tuples.

        Returns:
            List: Combined dataset for the specified run IDs.
        """

        data = []
        for session_name, run_name in tqdm(runIDs):
            root = self.roots[session_name]
            data.extend(io.load_training_data(
                root, [run_name], self.voxel_size, self.tomo_algorithm,
                self.target_name, self.target_session_id, self.target_user_id,
                progress_update=False))
        self._check_max_label_value(data)
        return data

    def create_train_dataloaders(self, *args, **kwargs):
        """
        Override data loading to fetch from multiple projects.
        """
        my_crop_size = kwargs.get("crop_size", 96)
        my_num_samples = kwargs.get("num_samples", 128)

        # If reloads are disabled and loaders already exist, reuse them
        if self.reload_frequency < 0 and (self.train_loader is not None) and (self.val_loader is not None):
            return self.train_loader, self.val_loader

        # Estimate the max number of workers with mp.cpu_count
        n_procs = min(mp.cpu_count(), 4)

        if self.train_loader is None:
            # Fetch the next batch of run IDs
            trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
            train_files = self._load_data(trainRunIDs)

            # Create the cached dataset with non-random transforms
            train_ds = SmartCacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=0.5)

            # Delete the training files to free memory
            train_files = None
            gc.collect()

            # Alternative cached dataset, left disabled in this release:
            # train_ds = cached_datset.MultiConfigCacheDataset(
            #     self, trainRunIDs, transform=augment.get_transforms(), cache_rate=1.0
            # )

            # Read (nx, ny, nz) and scale the crop size so it isn't larger than nx
            if self.nx is None:
                (self.nx, self.ny, self.nz) = train_ds[0]['image'].shape[1:]
            self.input_dim = io.get_input_dimensions(train_ds, my_crop_size)

            # Wrap the cached dataset to apply random transforms during iteration
            self.dynamic_train_dataset = dataset.DynamicDataset(
                data=train_ds,
                transform=augment.get_random_transforms(self.input_dim, my_num_samples, self.Nclasses)
            )

            self.train_loader = DataLoader(
                self.dynamic_train_dataset,
                batch_size=1,
                shuffle=True,
                num_workers=n_procs,
                pin_memory=torch.cuda.is_available(),
            )

        else:
            # Fetch the next batch of run IDs
            trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
            train_files = self._load_data(trainRunIDs)
            train_ds = CacheDataset(data=train_files, transform=augment.get_transforms(), cache_rate=1.0)
            self.dynamic_train_dataset.update_data(train_ds)

        # We only need to reload the validation dataset if the total number of runs
        # is larger than the tomo batch size
        if self.val_loader is None:

            validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
            val_files = self._load_data(validateRunIDs)

            # Create validation dataset
            val_ds = SmartCacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)

            # Delete the validation files to free memory
            val_files = None
            gc.collect()

            # Alternative cached dataset, left disabled in this release:
            # val_ds = cached_datset.MultiConfigCacheDataset(
            #     self, validateRunIDs, transform=augment.get_transforms(), cache_rate=1.0
            # )
            # if self.nx is None:
            #     (self.nx, self.ny, self.nz) = val_ds[0]['image'].shape[1:]
            # if crop_size > self.nx: self.input_dim = (self.nx, crop_size, crop_size)
            # else: self.input_dim = (crop_size, crop_size, crop_size)

            # Wrap the cached dataset to apply random transforms during iteration
            self.dynamic_validation_dataset = dataset.DynamicDataset(data=val_ds)

            # Create validation DataLoader
            self.val_loader = DataLoader(
                self.dynamic_validation_dataset,
                batch_size=1,
                num_workers=n_procs,
                pin_memory=torch.cuda.is_available(),
                shuffle=False,  # Ensure the data order remains consistent
            )
        else:
            validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
            val_files = self._load_data(validateRunIDs)

            val_ds = CacheDataset(data=val_files, transform=augment.get_transforms(), cache_rate=1.0)
            self.dynamic_validation_dataset.update_data(val_ds)

        return self.train_loader, self.val_loader


    def tmp_return_datasets(self):
        trainRunIDs = self._extract_run_ids('train_data_iter', self._initialize_train_iterators)
        train_files = self._load_data(trainRunIDs)

        validateRunIDs = self._extract_run_ids('val_data_iter', self._initialize_val_iterators)
        val_files = self._load_data(validateRunIDs)

        return train_files, val_files
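
For context, a construction sketch for the manager above (illustrative; paths, session names, and the target name are hypothetical, and the remaining setup such as batch sizes, reload frequency, and class count comes from the TrainLoaderManager base class, which is not shown in this diff):

from octopi.datasets.multi_config_generator import MultiConfigTrainLoaderManager

manager = MultiConfigTrainLoaderManager(
    configs={"session1": "exp1_config.json", "session2": "exp2_config.json"},
    target_name="targets",
    voxel_size=10,
    tomo_algorithm=['wbp'],
    tomo_batch_size=15,
)
manager.get_data_splits(train_ratio=0.8, val_ratio=0.1, test_ratio=0.1)
train_loader, val_loader = manager.create_train_dataloaders(crop_size=96, num_samples=128)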
File without changes
octopi/entry_points/common.py
ADDED
@@ -0,0 +1,119 @@
from octopi.utils import parsers
import rich_click as click
from typing import List, Tuple

def model_parameters(octopi: bool = False):
    """Decorator for adding model parameters"""
    def decorator(f):
        f = click.option("-dim", "--dim-in", type=int, default=96,
                         help="Input dimension for the UNet model")(f)
        f = click.option("--res-units", type=int, default=1,
                         help="Number of residual units in the UNet")(f)
        f = click.option("-s", "--strides", type=str, default='2,2,1',
                         callback=lambda ctx, param, value: parsers.parse_int_list(value) if value else value,
                         help="List of stride sizes")(f)
        f = click.option("-ch", "--channels", type=str, default='32,64,96,96',
                         callback=lambda ctx, param, value: parsers.parse_int_list(value) if value else value,
                         help="List of channel sizes")(f)
        return f
    return decorator

def inference_model_parameters():
    """Decorator for adding inference model parameters"""
    def decorator(f):
        f = click.option("-mw", "--model-weights", type=click.Path(exists=True), required=True,
                         help="Path to the model weights file")(f)
        f = click.option("-mc", "--model-config", type=click.Path(exists=True), required=True,
                         help="Path to the model configuration file")(f)
        return f
    return decorator

def train_parameters(octopi: bool = False):
    """Decorator for adding training parameters"""
    def decorator(f):
        if octopi:
            f = click.option("-nt", "--num-trials", type=int, default=10,
                             help="Number of trials for architecture search")(f)
        else:
            f = click.option("-o", "--model-save-path", type=click.Path(), default='results',
                             help="Path to model save directory")(f)
        f = click.option("--tversky-alpha", type=float, default=0.3,
                         help="Alpha parameter for the Tversky loss")(f)
        f = click.option("-lr", "--lr", type=float, default=1e-3,
                         help="Learning rate for the optimizer")(f)
        f = click.option('-ncrops', "--num-tomo-crops", type=int, default=16,
                         help="Number of tomogram crops to use per patch")(f)

        f = click.option("--best-metric", type=str, default='avg_f1',
                         help="Metric to monitor for determining the best model. To track fBetaN, use fBetaN with N as the beta value.")(f)
        f = click.option('-ntomos', "--tomo-batch-size", type=int, default=15,
                         help="Number of tomograms to load per epoch for training")(f)
        f = click.option("--val-interval", type=int, default=10,
                         help="Interval for validation metric calculations")(f)
        f = click.option('-nepochs', "--num-epochs", type=int, default=1000,
                         help="Number of training epochs")(f)
        return f
    return decorator

def config_parameters(single_config: bool):
    """Decorator for adding config parameters"""
    def decorator(f):
        f = click.option("-vs", "--voxel-size", type=float, default=10,
                         help="Voxel size of tomograms used")(f)
        if single_config:
            f = click.option("-c", "--config", type=click.Path(exists=True), required=True,
                             help="Path to the configuration file")(f)
        else:
            f = click.option("-c", "--config", type=str, multiple=True, required=True,
                             help="Specify a single configuration path (/path/to/config.json) or multiple entries in the format session_name,/path/to/config.json. Use multiple --config entries for multiple sessions.")(f)
        return f
    return decorator

def inference_parameters():
    """Decorator for adding inference parameters"""
    def decorator(f):
        f = click.option('-runs', "--run-ids", type=str, default=None,
                         callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
                         help="List of run IDs for prediction, e.g., run1,run2 or [run1,run2]. If not provided, all available runs will be processed.")(f)
        f = click.option('-ntomos', "--tomo-batch-size", type=int, default=25,
                         help="Batch size for tomogram processing")(f)
        f = click.option('-seginfo', "--seg-info", type=str, default='predict,octopi,1',
                         callback=lambda ctx, param, value: parsers.parse_target(value) if value else value,
                         help='Information query to save segmentation predictions under (e.g., "name" or "name,user_id,session_id"). Default UserID is octopi and SessionID is 1.')(f)
        f = click.option('-alg', "--tomo-alg", type=str, default='wbp',
                         help="Tomogram algorithm used to produce segmentation prediction masks")(f)
        return f
    return decorator

def localize_parameters():
    """Decorator for adding localization parameters"""
    def decorator(f):
        f = click.option('-seginfo', "--seg-info", type=str, required=True,
                         help="Segmentation info")(f)
        f = click.option("--pick-objects", type=str, required=True,
                         help="Pick objects")(f)
        f = click.option("--pick-session-id", type=str, default="1",
                         help="Pick session ID")(f)
        f = click.option('-m', "--method", type=str, default='watershed',
                         help="Localization method")(f)
        f = click.option('-vs', "--voxel-size", type=int, default=10,
                         help="Voxel size")(f)
        return f
    return decorator

def slurm_parameters(base_job_name: str, gpus: int = 1):
    """Decorator for adding SLURM parameters"""
    def decorator(f):
        if gpus > 1:
            f = click.option("--num-gpus", type=int, default=1,
                             help="Number of GPUs to use")(f)
        if gpus > 0:
            f = click.option("--gpu-constraint", type=click.Choice(['a6000', 'a100', 'h100', 'h200'], case_sensitive=False),
                             default='h100',
                             help="GPU constraint")(f)
        f = click.option("--job-name", type=str, default=base_job_name,
                         help="Job name for SLURM job")(f)
        f = click.option("--conda-env", type=click.Path(), default='/hpc/projects/group.czii/conda_environments/pyUNET/',
                         help="Path to Conda environment")(f)
        return f
    return decorator
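
For context, a sketch of how these decorators compose on a command (illustrative; the command and its body are hypothetical): each decorator attaches its click options to the function, which then receives them as keyword arguments.

import rich_click as click
from octopi.entry_points import common

@click.command()
@common.config_parameters(single_config=True)
@common.model_parameters()
@common.train_parameters()
def train(config, voxel_size, dim_in, res_units, strides, channels, **kwargs):
    """Hypothetical training entry point; remaining options land in kwargs."""
    print(config, voxel_size, channels, kwargs["num_epochs"])

if __name__ == "__main__":
    train()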