octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/io.py
ADDED
@@ -0,0 +1,457 @@
from monai.data import DataLoader, CacheDataset, Dataset
from monai.transforms import (
    Compose,
    NormalizeIntensityd,
    EnsureChannelFirstd,
)
from sklearn.model_selection import train_test_split
import copick, torch, os, json, random
from collections import defaultdict
from octopi import utils
from typing import List
from tqdm import tqdm
import numpy as np

##############################################################################################################################

def load_training_data(root,
                       runIDs: List[str],
                       voxel_spacing: float,
                       tomo_algorithm: str,
                       segmenation_name: str,
                       segmentation_session_id: str = None,
                       segmentation_user_id: str = None,
                       progress_update: bool = True):

    data_dicts = []
    # Use tqdm for progress tracking only if progress_update is True
    iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
    for runID in iterable:
        run = root.get_run(str(runID))
        tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
        segmentation = get_segmentation_array(run,
                                              voxel_spacing,
                                              segmenation_name,
                                              segmentation_session_id,
                                              segmentation_user_id)
        data_dicts.append({"image": tomogram, "label": segmentation})

    return data_dicts

##############################################################################################################################

def load_predict_data(root,
                      runIDs: List[str],
                      voxel_spacing: float,
                      tomo_algorithm: str):

    data_dicts = []
    for runID in tqdm(runIDs):
        run = root.get_run(str(runID))
        tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
        data_dicts.append({"image": tomogram})

    return data_dicts

##############################################################################################################################

def create_predict_dataloader(
    root,
    voxel_spacing: float,
    tomo_algorithm: str,
    runIDs: str = None,
    ):

    # define pre transforms
    pre_transforms = Compose(
        [ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
          NormalizeIntensityd(keys=["image"]),
        ])

    # Split trainRunIDs, validateRunIDs, testRunIDs
    if runIDs is None:
        runIDs = [run.name for run in root.runs]
    test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)

    bs = min( len(test_files), 4)
    test_ds = CacheDataset(data=test_files, transform=pre_transforms)
    test_loader = DataLoader(test_ds,
                             batch_size=bs,
                             shuffle=False,
                             num_workers=4,
                             pin_memory=torch.cuda.is_available())
    return test_loader, test_ds

##############################################################################################################################

def get_tomogram_array(run,
                       voxel_size: float = 10,
                       tomo_type: str = 'wbp',
                       raise_error: bool = True):

    voxel_spacing_obj = run.get_voxel_spacing(voxel_size)

    if voxel_spacing_obj is None:
        # Query Avaiable Voxel Spacings
        availableVoxelSpacings = [tomo.voxel_size for tomo in run.voxel_spacings]

        # Report to the user which voxel spacings they can use
        message = (f"\n[Warning] No tomogram found for {run.name} with voxel size {voxel_size} and tomogram type {tomo_type}"
                   f"\nAvailable spacings are: {', '.join(map(str, availableVoxelSpacings))}\n" )
        if raise_error:
            raise ValueError(message)
        else:
            print(message)
            return None

    tomogram = voxel_spacing_obj.get_tomogram(tomo_type)
    if tomogram is None:
        # Get available algorithms
        availableAlgorithms = [tomo.tomo_type for tomo in run.get_voxel_spacing(voxel_size).tomograms]

        # Report to the user which algorithms are available
        message = (f"\n[Warning] No tomogram found for {run.name} with voxel size {voxel_size} and tomogram type {tomo_type}"
                   f"\nAvailable algorithms are: {', '.join(availableAlgorithms)}\n")
        if raise_error:
            raise ValueError(message)
        else:
            print(message)
            return None

    return tomogram.numpy().astype(np.float32)

##############################################################################################################################

def get_segmentation_array(run,
                           voxel_spacing: float,
                           segmentation_name: str,
                           session_id=None,
                           user_id=None,
                           raise_error: bool = True):

    seg = run.get_segmentations(name=segmentation_name,
                                session_id = session_id,
                                user_id = user_id,
                                voxel_size=float(voxel_spacing))

    # No Segmentations Are Available, Result in Error
    if len(seg) == 0:
        # Get all available segmentations with their metadata
        available_segs = run.get_segmentations(voxel_size=voxel_spacing)
        seg_info = [(s.name, s.user_id, s.session_id) for s in available_segs]

        # Format the information for display
        seg_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
                       for name, uid, sid in seg_info]

        message = ( f'\nNo segmentation found matching:\n'
                    f' name: {segmentation_name}, user_id: {user_id}, session_id: {session_id}\n'
                    f'Available segmentations in {run.name} are:\n ' +
                    '\n '.join(seg_details) )
        if raise_error:
            raise ValueError(message)
        else:
            print(message)
            return None

    # No Segmentations Are Available, Result in Error
    if len(seg) > 1:
        print(f'[Warning] More Than 1 Segmentation is Available for the Query Information. '
              f'Available Segmentations are: {seg} '
              f'Defaulting to Loading: {seg[0]}\n')
    seg = seg[0]

    return seg.numpy().astype(np.int8)

##############################################################################################################################

def get_copick_coordinates(run,  # CoPick run object containing the segmentation data
                           name: str,  # Name of the object or protein for which coordinates are being extracted
                           user_id: str,  # Identifier of the user that generated the picks
                           session_id: str = None,  # Identifier of the session that generated the picks
                           voxel_size: float = 10,  # Voxel size of the tomogram, used for scaling the coordinates
                           raise_error: bool = True):

    # Retrieve the pick points associated with the specified object and user ID
    picks = run.get_picks(object_name=name, user_id=user_id, session_id=session_id)

    if len(picks) == 0:
        # Get all available segmentations with their metadata

        available_picks = run.get_picks()
        picks_info = [(s.pickable_object_name, s.user_id, s.session_id) for s in available_picks]

        # Format the information for display
        picks_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
                         for name, uid, sid in picks_info]

        message = ( f'\nNo picks found matching:\n'
                    f' name: {name}, user_id: {user_id}, session_id: {session_id}\n'
                    f'Available picks are:\n '
                    + '\n '.join(picks_details) )
        if raise_error:
            raise ValueError(message)
        else:
            print(message)
            return None
    elif len(picks) > 1:
        # Format pick information for display
        picks_info = [(p.pickable_object_name, p.user_id, p.session_id) for p in picks]
        picks_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
                         for name, uid, sid in picks_info]

        print(f'[Warning] More than 1 pick is available for the query information.'
              f'\nAvailable picks are:\n ' +
              '\n '.join(picks_details) +
              f'\nDefaulting to loading:\n {picks[0]}\n')
    points = picks[0].points

    # Initialize an array to store the coordinates
    nPoints = len(picks[0].points)  # Number of points retrieved
    coordinates = np.zeros([len(picks[0].points), 3])  # Create an empty array to hold the (z, y, x) coordinates

    # Iterate over all points and convert their locations to coordinates in voxel space
    for ii in range(nPoints):
        coordinates[ii,] = [points[ii].location.z / voxel_size,  # Scale z-coordinate by voxel size
                            points[ii].location.y / voxel_size,  # Scale y-coordinate by voxel size
                            points[ii].location.x / voxel_size]  # Scale x-coordinate by voxel size

    # Return the array of coordinates
    return coordinates


##############################################################################################################################

def adjust_to_multiple(value, multiple = 16):
    return int((value // multiple) * multiple)

def get_input_dimensions(dataset, crop_size: int):
    nx = dataset[0]['image'].shape[1]
    if crop_size > nx:
        first_dim = adjust_to_multiple(nx/2)
        return first_dim, crop_size, crop_size
    else:
        return crop_size, crop_size, crop_size

def get_num_classes(copick_config_path: str):

    root = copick.from_file(copick_config_path)
    return len(root.pickable_objects) + 1

def split_multiclass_dataset(runIDs,
                             train_ratio: float = 0.7,
                             val_ratio: float = 0.15,
                             test_ratio: float = 0.15,
                             return_test_dataset: bool = True,
                             random_state: int = 42):
    """
    Splits a given dataset into three subsets: training, validation, and testing. If the dataset
    has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
    list, it is split without categorization.

    Parameters:
    - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
    - train_ratio: Proportion of the dataset for training.
    - val_ratio: Proportion of the dataset for validation.
    - test_ratio: Proportion of the dataset for testing.
    - return_test_dataset: Whether to return the test dataset.
    - random_state: Random state for reproducibility.

    Returns:
    - trainRunIDs: Training subset.
    - valRunIDs: Validation subset.
    - testRunIDs: Testing subset (if return_test_dataset is True, otherwise None).
    """

    # Ensure the ratios add up to 1
    assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"

    # Check if the dataset has categories
    if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
        # Group by category
        grouped = defaultdict(list)
        for item in runIDs:
            grouped[item[0]].append(item)

        # Split each category
        trainRunIDs, valRunIDs, testRunIDs = [], [], []
        for category, items in grouped.items():
            # Shuffle for randomness
            random.shuffle(items)
            # Split into train and remaining
            train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
            trainRunIDs.extend(train_items)

            if return_test_dataset:
                # Split remaining into validation and test
                val_items, test_items = train_test_split(
                    remaining,
                    test_size=(test_ratio / (val_ratio + test_ratio)),
                    random_state=random_state,
                )
                valRunIDs.extend(val_items)
                testRunIDs.extend(test_items)
            else:
                valRunIDs.extend(remaining)
                testRunIDs = []
    else:
        # If no categories, split as a 1D list
        trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
        if return_test_dataset:
            valRunIDs, testRunIDs = train_test_split(
                remaining,
                test_size=(test_ratio / (val_ratio + test_ratio)),
                random_state=random_state,
            )
        else:
            valRunIDs = remaining
            testRunIDs = []

    return trainRunIDs, valRunIDs, testRunIDs

##############################################################################################################################

def load_copick_config(path: str):

    if os.path.isfile(path):
        root = copick.from_file(path)
    else:
        raise FileNotFoundError(f"Copick Config Path does not exist: {path}")

    return root

##############################################################################################################################

# Helper function to flatten and serialize nested parameters
def flatten_params(params, parent_key=''):
    flattened = {}
    for key, value in params.items():
        new_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            flattened.update(flatten_params(value, new_key))
        elif isinstance(value, list):
            flattened[new_key] = ', '.join(map(str, value))  # Convert list to a comma-separated string
        else:
            flattened[new_key] = value
    return flattened

# Manually join specific lists into strings for inline display
def prepare_for_inline_json(data):
    for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
        if key in data['dataloader']:
            data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"

    for key in ['channels', 'strides']:
        if key in data['model']:
            data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
    return data

def get_optimizer_parameters(trainer):

    optimizer_parameters = {
        'my_num_samples': trainer.num_samples,
        'val_interval': trainer.val_interval,
        'lr': trainer.optimizer.param_groups[0]['lr'],
        'optimizer': trainer.optimizer.__class__.__name__,
        'metrics_function': trainer.metrics_function.__class__.__name__,
        'loss_function': trainer.loss_function.__class__.__name__,
    }

    # Log Tversky Loss Parameters
    if trainer.loss_function.__class__.__name__ == 'TverskyLoss':
        optimizer_parameters['alpha'] = trainer.loss_function.alpha
    elif trainer.loss_function.__class__.__name__ == 'FocalLoss':
        optimizer_parameters['gamma'] = trainer.loss_function.gamma
    elif trainer.loss_function.__class__.__name__ == 'WeightedFocalTverskyLoss':
        optimizer_parameters['alpha'] = trainer.loss_function.alpha
        optimizer_parameters['gamma'] = trainer.loss_function.gamma
        optimizer_parameters['weight_tversky'] = trainer.loss_function.weight_tversky
    elif trainer.loss_function.__class__.__name__ == 'FocalTverskyLoss':
        optimizer_parameters['alpha'] = trainer.loss_function.alpha
        optimizer_parameters['gamma'] = trainer.loss_function.gamma

    return optimizer_parameters

def save_parameters_to_yaml(model, trainer, dataloader, filename: str):

    parameters = {
        'model': model.get_model_parameters(),
        'optimizer': get_optimizer_parameters(trainer),
        'dataloader': dataloader.get_dataloader_parameters()
    }

    utils.save_parameters_yaml(parameters, filename)
    print(f"Training Parameters saved to {filename}")

def prepare_inline_results_json(results):
    # Traverse the dictionary and format lists of lists as inline JSON
    for key, value in results.items():
        # Check if the value is a list of lists (like [[epoch, value], ...])
        if isinstance(value, list) and all(isinstance(item, list) and len(item) == 2 for item in value):
            # Format the list of lists as a single-line JSON string
            results[key] = json.dumps(value)
    return results

# Check to See if I'm Happy with This... Maybe Save as H5 File?
def save_results_to_json(results, filename: str):

    results = prepare_inline_results_json(results)
    with open(os.path.join(filename), "w") as json_file:
        json.dump( results, json_file, indent=4 )
    print(f"Training Results saved to {filename}")

##############################################################################################################################

# def save_parameters_to_json(model, trainer, dataloader, filename: str):

#     parameters = {
#         'model': model.get_model_parameters(),
#         'optimizer': get_optimizer_parameters(trainer),
#         'dataloader': dataloader.get_dataloader_parameters()
#     }
#     parameters = prepare_for_inline_json(parameters)

#     with open(os.path.join(filename), "w") as json_file:
#         json.dump( parameters, json_file, indent=4 )
#     print(f"Training Parameters saved to {filename}")

# def split_datasets(runIDs,
#                    train_ratio: float = 0.7,
#                    val_ratio: float = 0.15,
#                    test_ratio: float = 0.15,
#                    return_test_dataset: bool = True,
#                    random_state: int = 42):
#     """
#     Splits a given dataset into three subsets: training, validation, and testing. The proportions
#     of each subset are determined by the provided ratios, ensuring that they add up to 1. The
#     function uses a fixed random state for reproducibility.

#     Parameters:
#     - runIDs: The complete dataset that needs to be split.
#     - train_ratio: The proportion of the dataset to be used for training.
#     - val_ratio: The proportion of the dataset to be used for validation.
#     - test_ratio: The proportion of the dataset to be used for testing.

#     Returns:
#     - trainRunIDs: The subset of the dataset used for training.
#     - valRunIDs: The subset of the dataset used for validation.
#     - testRunIDs: The subset of the dataset used for testing.
#     """

#     # Ensure the ratios add up to 1
#     assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"

#     # First, split into train and remaining (30%)
#     trainRunIDs, valRunIDs = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)

#     # (Optional) split the remaining into validation and test
#     if return_test_dataset:
#         valRunIDs, testRunIDs = train_test_split(
#             valRunIDs,
#             test_size=(test_ratio / (val_ratio + test_ratio)),
#             random_state=random_state,
#         )
#     else:
#         testRunIDs = None

#     return trainRunIDs, valRunIDs, testRunIDs
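Illustrative usage of octopi/io.py (editor's sketch, not part of the package): the helpers above can be chained to load a copick project, split its runs, and build training or inference inputs. The config path, segmentation name, voxel spacing, and tomogram algorithm below are placeholders, and the project is assumed to already contain matching tomograms and segmentations.

# Hypothetical usage sketch for octopi.io; all paths and names are placeholders.
from octopi import io

root = io.load_copick_config("copick_config.json")      # raises FileNotFoundError if the path is missing
n_classes = io.get_num_classes("copick_config.json")    # number of pickable objects + background

# Split run IDs into train/val/test subsets (70/15/15 by default)
run_ids = [run.name for run in root.runs]
train_ids, val_ids, test_ids = io.split_multiclass_dataset(run_ids)

# Build {"image": ..., "label": ...} dictionaries for training
train_dicts = io.load_training_data(
    root, train_ids,
    voxel_spacing=10.0,
    tomo_algorithm="wbp",
    segmenation_name="targets",   # keyword spelled as in io.py; "targets" is a placeholder name
)

# Or build a cached DataLoader over every run for inference
loader, dataset = io.create_predict_dataloader(root, voxel_spacing=10.0, tomo_algorithm="wbp")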
octopi/losses.py
ADDED
@@ -0,0 +1,86 @@
from monai.losses import FocalLoss, TverskyLoss
import torch

class WeightedFocalTverskyLoss(torch.nn.Module):
    def __init__(
        self, gamma=1.0, alpha=0.7, beta=0.3,
        weight_tversky=0.5, weight_focal=0.5,
        smooth=1e-5, **kwargs ):
        """
        Weighted combination of Focal and Tversky loss.

        Args:
            gamma (float): Focus parameter for Focal Loss.
            alpha (float): Weight for false positives in Tversky Loss.
            beta (float): Weight for false negatives in Tversky Loss.
            weight_tversky (float): Weight of Tversky loss in the combination.
            weight_focal (float): Weight of Focal loss in the combination.
            smooth (float): Smoothing factor to avoid division by zero.
        """
        super().__init__()
        self.tversky_loss = TverskyLoss(
            alpha=alpha, beta=beta, include_background=True,
            to_onehot_y=True, softmax=True,
            smooth_nr=smooth, smooth_dr=smooth, **kwargs
        )
        self.focal_loss = FocalLoss(
            include_background=True, to_onehot_y=True,
            use_softmax=True, gamma=gamma
        )
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.weight_tversky = weight_tversky
        self.weight_focal = weight_focal

    def forward(self, y_pred, y_true):
        """
        Compute the combined loss.

        Args:
            y_pred (Tensor): Predicted probabilities (B, C, ...).
            y_true (Tensor): Ground truth labels (B, C, ...), one-hot encoded.

        Returns:
            Tensor: Weighted combination of Tversky and Focal losses.
        """
        tversky = self.tversky_loss(y_pred, y_true)
        focal = self.focal_loss(y_pred, y_true)
        return self.weight_tversky * tversky + self.weight_focal * focal

class FocalTverskyLoss(TverskyLoss):
    def __init__(
        self,
        alpha=0.7, beta=0.3, gamma=1.0, smooth=1e-5, **kwargs):
        """
        Focal Tversky Loss with an additional power term for harder samples.

        From https://arxiv.org/abs/1810.07842

        Args:
            alpha (float): Weight for false positives.
            beta (float): Weight for false negatives.
            gamma (float): Focus parameter (like Focal Loss).
            smooth (float): Smoothing factor to avoid division by zero.
        """
        super().__init__(
            alpha=alpha, beta=beta,
            include_background=True,
            to_onehot_y=True, softmax=True,
            smooth_nr=smooth, smooth_dr=smooth, **kwargs)
        self.gamma = gamma
        self.alpha = alpha
        self.beta = beta

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (Tensor): Predicted probabilities (B, C, ...).
            y_true (Tensor): Ground truth labels (B, C, ...), one-hot encoded.

        Returns:
            Tensor: Loss value.
        """
        tversky_loss = super().forward(y_pred, y_true)
        modified_loss = torch.pow(tversky_loss, 1 / self.gamma)
        return modified_loss
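A minimal smoke-test sketch for octopi/losses.py (editor's illustration, not packaged code): both classes take raw logits and integer labels with a single channel, since the underlying MONAI losses are configured with to_onehot_y=True and softmax. The batch size, class count, and volume shape below are assumptions.

# Hypothetical smoke test; shapes and class count are illustrative only.
import torch
from octopi.losses import WeightedFocalTverskyLoss, FocalTverskyLoss

num_classes = 3
logits = torch.randn(2, num_classes, 32, 32, 32)             # (B, C, D, H, W) raw network output
labels = torch.randint(0, num_classes, (2, 1, 32, 32, 32))   # integer class labels, one channel

combined = WeightedFocalTverskyLoss(gamma=1.0, alpha=0.7, beta=0.3,
                                    weight_tversky=0.5, weight_focal=0.5)
focal_tversky = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=1.5)

print(combined(logits, labels))       # weighted sum of the Tversky and Focal terms
print(focal_tversky(logits, labels))  # Tversky loss raised to the power 1/gamma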
octopi/main.py
ADDED
@@ -0,0 +1,101 @@
from octopi.processing.importers import cli_dataportal as download_dataportal
from octopi.processing.importers import cli_mrcs as import_mrc_volumes
from octopi.entry_points.run_create_targets import cli as create_targets
from octopi.entry_points.run_train import cli as train_model
from octopi.entry_points.run_optuna import cli as model_explore
from octopi.entry_points.run_segment_predict import cli as inference
from octopi.entry_points.run_localize import cli as localize
from octopi.entry_points.run_evaluate import cli as evaluate
from octopi.entry_points.run_extract_mb_picks import cli as extract_mb_picks
import octopi.entry_points.create_slurm_submission as slurm_submitter
import argparse
import sys

def cli_main():
    """
    Main CLI entry point for octopi.
    """

    # Create the main parser
    parser = argparse.ArgumentParser(
        description="Octopi 🐙: 🛠️ Tools for Finding Proteins in 🧊 cryo-ET data",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # Create subparsers
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    subparsers.required = True  # Make subcommand required

    # Define all subcommands with their help text
    commands = {
        "import-mrc-volumes": (import_mrc_volumes, "Import MRC volumes from a directory, we can downsample to smaller voxel size if desired."),
        "download-dataportal": (download_dataportal, "Download tomograms from the Dataportal, we can downsample to smaller voxel size if desired."),
        "create-targets": (create_targets, "Generate segmentation targets from coordinates."),
        "train": (train_model, "Train a single U-Net model."),
        "model-explore": (model_explore, "Explore model architectures with Optuna / Bayesian Optimization."),
        "inference": (inference, "Perform segmentation inference on tomograms."),
        "localize": (localize, "Perform localization of particles in tomograms."),
        "extract-mb-picks": (extract_mb_picks, "Extract MB Picks from tomograms."),
        "evaluate": (evaluate, "Evaluate the performance of a model."),
    }

    # Add all subparsers and their help text
    for cmd_name, (cmd_func, cmd_help) in commands.items():
        subparsers.add_parser(cmd_name, help=cmd_help)

    # Parse just the command part to determine which subcommand was chosen
    if len(sys.argv) > 1 and sys.argv[1] in commands:
        command = sys.argv[1]
        cmd_func = commands[command][0]

        # Remove the first argument (command name) and call the appropriate CLI function
        sys.argv = [sys.argv[0]] + sys.argv[2:]
        cmd_func()
    else:
        # Just show help if no valid command
        parser.parse_args()

def cli_slurm_main():
    """
    SLURM-specific CLI entry point for octopi.
    """

    # Create the main parser
    parser = argparse.ArgumentParser(
        description="Octopi for SLURM 🖥️: Shell Submission Tools for Running 🐙 on HPC",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # Create subparsers
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    subparsers.required = True  # Make subcommand required

    # Define all subcommands with their help text
    commands = {
        "import-mrc-volumes": (slurm_submitter.import_mrc_slurm, "Import MRC volumes from a directory."),
        "download-dataportal": (slurm_submitter.download_dataportal_slurm, "Download tomograms from the Dataportal, we can downsample to smaller voxel size if desired."),
        # "create-targets": (create_targets, "Generate segmentation targets from coordinates."),
        "train": (slurm_submitter.train_model_slurm, "Train a single U-Net model."),
        "model-explore": (slurm_submitter.model_explore_slurm, "Explore model architectures with Optuna / Bayesian Optimization."),
        "inference": (slurm_submitter.inference_slurm, "Perform segmentation inference on tomograms."),
        "localize": (slurm_submitter.localize_slurm, "Perform localization of particles in tomograms."),
        # "extract-mb-picks": (extract_mb_picks, "Extract MB Picks from tomograms.")
        # "evaluate": (evaluate, "Evaluate the performance of a model."),
    }

    # Add all subparsers and their help text
    for cmd_name, (cmd_func, cmd_help) in commands.items():
        subparsers.add_parser(cmd_name, help=cmd_help)

    # Parse just the command part to determine which subcommand was chosen
    if len(sys.argv) > 1 and sys.argv[1] in commands:
        command = sys.argv[1]
        cmd_func = commands[command][0]

        # Remove the first argument (command name) and call the appropriate CLI function
        sys.argv = [sys.argv[0]] + sys.argv[2:]
        cmd_func()
    else:
        # Just show help if no valid command
        parser.parse_args()