octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of octopi has been flagged as potentially problematic.

Files changed (59):
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/io.py ADDED
@@ -0,0 +1,457 @@
+ from monai.data import DataLoader, CacheDataset, Dataset
+ from monai.transforms import (
+     Compose,
+     NormalizeIntensityd,
+     EnsureChannelFirstd,
+ )
+ from sklearn.model_selection import train_test_split
+ import copick, torch, os, json, random
+ from collections import defaultdict
+ from octopi import utils
+ from typing import List
+ from tqdm import tqdm
+ import numpy as np
+
+ ##############################################################################################################################
+
+ def load_training_data(root,
+                        runIDs: List[str],
+                        voxel_spacing: float,
+                        tomo_algorithm: str,
+                        segmentation_name: str,
+                        segmentation_session_id: str = None,
+                        segmentation_user_id: str = None,
+                        progress_update: bool = True):
+
+     data_dicts = []
+     # Use tqdm for progress tracking only if progress_update is True
+     iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
+     for runID in iterable:
+         run = root.get_run(str(runID))
+         tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
+         segmentation = get_segmentation_array(run,
+                                               voxel_spacing,
+                                               segmentation_name,
+                                               segmentation_session_id,
+                                               segmentation_user_id)
+         data_dicts.append({"image": tomogram, "label": segmentation})
+
+     return data_dicts
+
+ ##############################################################################################################################
+
+ def load_predict_data(root,
+                       runIDs: List[str],
+                       voxel_spacing: float,
+                       tomo_algorithm: str):
+
+     data_dicts = []
+     for runID in tqdm(runIDs, desc="Loading Prediction Data"):
+         run = root.get_run(str(runID))
+         tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
+         data_dicts.append({"image": tomogram})
+
+     return data_dicts
+
+ ##############################################################################################################################
+
+ def create_predict_dataloader(
+     root,
+     voxel_spacing: float,
+     tomo_algorithm: str,
+     runIDs: List[str] = None,
+ ):
+
+     # Define pre-transforms
+     pre_transforms = Compose(
+         [ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
+           NormalizeIntensityd(keys=["image"]),
+         ])
+
+     # Default to all runs if no runIDs are provided
+     if runIDs is None:
+         runIDs = [run.name for run in root.runs]
+     test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)
+
+     bs = min(len(test_files), 4)
+     test_ds = CacheDataset(data=test_files, transform=pre_transforms)
+     test_loader = DataLoader(test_ds,
+                              batch_size=bs,
+                              shuffle=False,
+                              num_workers=4,
+                              pin_memory=torch.cuda.is_available())
+     return test_loader, test_ds
+
+ ##############################################################################################################################
+
+ def get_tomogram_array(run,
+                        voxel_size: float = 10,
+                        tomo_type: str = 'wbp',
+                        raise_error: bool = True):
+
+     voxel_spacing_obj = run.get_voxel_spacing(voxel_size)
+
+     if voxel_spacing_obj is None:
+         # Query available voxel spacings
+         availableVoxelSpacings = [vs.voxel_size for vs in run.voxel_spacings]
+
+         # Report to the user which voxel spacings they can use
+         message = (f"\n[Warning] No voxel spacing of {voxel_size} found for {run.name} (tomogram type {tomo_type})"
+                    f"\nAvailable spacings are: {', '.join(map(str, availableVoxelSpacings))}\n")
+         if raise_error:
+             raise ValueError(message)
+         else:
+             print(message)
+             return None
+
+     tomogram = voxel_spacing_obj.get_tomogram(tomo_type)
+     if tomogram is None:
+         # Get available reconstruction algorithms at this voxel spacing
+         availableAlgorithms = [tomo.tomo_type for tomo in voxel_spacing_obj.tomograms]
+
+         # Report to the user which algorithms are available
+         message = (f"\n[Warning] No tomogram found for {run.name} with voxel size {voxel_size} and tomogram type {tomo_type}"
+                    f"\nAvailable algorithms are: {', '.join(availableAlgorithms)}\n")
+         if raise_error:
+             raise ValueError(message)
+         else:
+             print(message)
+             return None
+
+     return tomogram.numpy().astype(np.float32)
+
+ ##############################################################################################################################
+
+ def get_segmentation_array(run,
+                            voxel_spacing: float,
+                            segmentation_name: str,
+                            session_id=None,
+                            user_id=None,
+                            raise_error: bool = True):
+
+     seg = run.get_segmentations(name=segmentation_name,
+                                 session_id=session_id,
+                                 user_id=user_id,
+                                 voxel_size=float(voxel_spacing))
+
+     # No segmentations are available: report the alternatives, then raise or return
+     if len(seg) == 0:
+         # Get all available segmentations with their metadata
+         available_segs = run.get_segmentations(voxel_size=voxel_spacing)
+         seg_info = [(s.name, s.user_id, s.session_id) for s in available_segs]
+
+         # Format the information for display
+         seg_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
+                        for name, uid, sid in seg_info]
+
+         message = (f'\nNo segmentation found matching:\n'
+                    f'  name: {segmentation_name}, user_id: {user_id}, session_id: {session_id}\n'
+                    f'Available segmentations in {run.name} are:\n  ' +
+                    '\n  '.join(seg_details))
+         if raise_error:
+             raise ValueError(message)
+         else:
+             print(message)
+             return None
+
+     # More than one segmentation matches the query: warn and default to the first
+     if len(seg) > 1:
+         print(f'[Warning] More Than 1 Segmentation is Available for the Query Information. '
+               f'Available Segmentations are: {seg} '
+               f'Defaulting to Loading: {seg[0]}\n')
+     seg = seg[0]
+
+     return seg.numpy().astype(np.int8)
+
+ ##############################################################################################################################
+
+ def get_copick_coordinates(run,                      # copick run object containing the picks
+                            name: str,                # name of the object or protein whose coordinates are extracted
+                            user_id: str,             # identifier of the user that generated the picks
+                            session_id: str = None,   # identifier of the session that generated the picks
+                            voxel_size: float = 10,   # voxel size of the tomogram, used for scaling the coordinates
+                            raise_error: bool = True):
+
+     # Retrieve the pick points associated with the specified object and user ID
+     picks = run.get_picks(object_name=name, user_id=user_id, session_id=session_id)
+
+     if len(picks) == 0:
+         # Get all available picks with their metadata
+         available_picks = run.get_picks()
+         picks_info = [(p.pickable_object_name, p.user_id, p.session_id) for p in available_picks]
+
+         # Format the information for display
+         picks_details = [f"(name: {pname}, user_id: {uid}, session_id: {sid})"
+                          for pname, uid, sid in picks_info]
+
+         message = (f'\nNo picks found matching:\n'
+                    f'  name: {name}, user_id: {user_id}, session_id: {session_id}\n'
+                    f'Available picks are:\n  '
+                    + '\n  '.join(picks_details))
+         if raise_error:
+             raise ValueError(message)
+         else:
+             print(message)
+             return None
+     elif len(picks) > 1:
+         # Format pick information for display
+         picks_info = [(p.pickable_object_name, p.user_id, p.session_id) for p in picks]
+         picks_details = [f"(name: {pname}, user_id: {uid}, session_id: {sid})"
+                          for pname, uid, sid in picks_info]
+
+         print(f'[Warning] More than 1 pick is available for the query information.'
+               f'\nAvailable picks are:\n  ' +
+               '\n  '.join(picks_details) +
+               f'\nDefaulting to loading:\n  {picks[0]}\n')
+     points = picks[0].points
+
+     # Initialize an array to hold the (z, y, x) coordinates
+     nPoints = len(points)
+     coordinates = np.zeros([nPoints, 3])
+
+     # Convert each point location to voxel space by dividing by the voxel size
+     for ii in range(nPoints):
+         coordinates[ii, :] = [points[ii].location.z / voxel_size,
+                               points[ii].location.y / voxel_size,
+                               points[ii].location.x / voxel_size]
+
+     # Return the array of coordinates
+     return coordinates
+
+ ##############################################################################################################################
+
+ def adjust_to_multiple(value, multiple=16):
+     # Round value down to the nearest multiple (default: 16, for network-friendly dimensions)
+     return int((value // multiple) * multiple)
+
+ def get_input_dimensions(dataset, crop_size: int):
+     # If the crop size exceeds the tomogram depth, shrink the first dimension
+     # to half the depth, rounded down to a multiple of 16
+     nx = dataset[0]['image'].shape[1]
+     if crop_size > nx:
+         first_dim = adjust_to_multiple(nx / 2)
+         return first_dim, crop_size, crop_size
+     else:
+         return crop_size, crop_size, crop_size
+
+ def get_num_classes(copick_config_path: str):
+
+     # One class per pickable object, plus one for background
+     root = copick.from_file(copick_config_path)
+     return len(root.pickable_objects) + 1
+
+ def split_multiclass_dataset(runIDs,
+                              train_ratio: float = 0.7,
+                              val_ratio: float = 0.15,
+                              test_ratio: float = 0.15,
+                              return_test_dataset: bool = True,
+                              random_state: int = 42):
+     """
+     Splits a given dataset into three subsets: training, validation, and testing. If the dataset
+     has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
+     list, it is split without categorization.
+
+     Parameters:
+     - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
+     - train_ratio: Proportion of the dataset for training.
+     - val_ratio: Proportion of the dataset for validation.
+     - test_ratio: Proportion of the dataset for testing.
+     - return_test_dataset: Whether to return the test dataset.
+     - random_state: Random state for reproducibility.
+
+     Returns:
+     - trainRunIDs: Training subset.
+     - valRunIDs: Validation subset.
+     - testRunIDs: Testing subset (empty list if return_test_dataset is False).
+     """
+
+     # Ensure the ratios add up to 1 (with tolerance for floating-point rounding)
+     assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must add up to 1.0"
+
+     # Check if the dataset has categories
+     if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
+         # Group by category
+         grouped = defaultdict(list)
+         for item in runIDs:
+             grouped[item[0]].append(item)
+
+         # Split each category so every subset is balanced across categories
+         trainRunIDs, valRunIDs, testRunIDs = [], [], []
+         for category, items in grouped.items():
+             # Shuffle for randomness (seeded so results stay reproducible)
+             random.Random(random_state).shuffle(items)
+             # Split into train and remaining
+             train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
+             trainRunIDs.extend(train_items)
+
+             if return_test_dataset:
+                 # Split remaining into validation and test
+                 val_items, test_items = train_test_split(
+                     remaining,
+                     test_size=(test_ratio / (val_ratio + test_ratio)),
+                     random_state=random_state,
+                 )
+                 valRunIDs.extend(val_items)
+                 testRunIDs.extend(test_items)
+             else:
+                 valRunIDs.extend(remaining)
+                 testRunIDs = []
+     else:
+         # If no categories, split as a 1D list
+         trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
+         if return_test_dataset:
+             valRunIDs, testRunIDs = train_test_split(
+                 remaining,
+                 test_size=(test_ratio / (val_ratio + test_ratio)),
+                 random_state=random_state,
+             )
+         else:
+             valRunIDs = remaining
+             testRunIDs = []
+
+     return trainRunIDs, valRunIDs, testRunIDs
+
+ ##############################################################################################################################
+
+ def load_copick_config(path: str):
+
+     if os.path.isfile(path):
+         root = copick.from_file(path)
+     else:
+         raise FileNotFoundError(f"Copick Config Path does not exist: {path}")
+
+     return root
+
+ ##############################################################################################################################
+
+ # Helper function to flatten and serialize nested parameters
+ def flatten_params(params, parent_key=''):
+     flattened = {}
+     for key, value in params.items():
+         new_key = f"{parent_key}.{key}" if parent_key else key
+         if isinstance(value, dict):
+             flattened.update(flatten_params(value, new_key))
+         elif isinstance(value, list):
+             flattened[new_key] = ', '.join(map(str, value))  # Convert list to a comma-separated string
+         else:
+             flattened[new_key] = value
+     return flattened
+
+ # Manually join specific lists into strings for inline display
+ def prepare_for_inline_json(data):
+     for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
+         if key in data['dataloader']:
+             data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"
+
+     for key in ['channels', 'strides']:
+         if key in data['model']:
+             data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
+     return data
+
+ def get_optimizer_parameters(trainer):
+
+     optimizer_parameters = {
+         'my_num_samples': trainer.num_samples,
+         'val_interval': trainer.val_interval,
+         'lr': trainer.optimizer.param_groups[0]['lr'],
+         'optimizer': trainer.optimizer.__class__.__name__,
+         'metrics_function': trainer.metrics_function.__class__.__name__,
+         'loss_function': trainer.loss_function.__class__.__name__,
+     }
+
+     # Log loss-specific parameters
+     if trainer.loss_function.__class__.__name__ == 'TverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+     elif trainer.loss_function.__class__.__name__ == 'FocalLoss':
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+     elif trainer.loss_function.__class__.__name__ == 'WeightedFocalTverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+         optimizer_parameters['weight_tversky'] = trainer.loss_function.weight_tversky
+     elif trainer.loss_function.__class__.__name__ == 'FocalTverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+
+     return optimizer_parameters
+
+ def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
+
+     parameters = {
+         'model': model.get_model_parameters(),
+         'optimizer': get_optimizer_parameters(trainer),
+         'dataloader': dataloader.get_dataloader_parameters()
+     }
+
+     utils.save_parameters_yaml(parameters, filename)
+     print(f"Training Parameters saved to {filename}")
+
+ def prepare_inline_results_json(results):
+     # Traverse the dictionary and format lists of lists as inline JSON
+     for key, value in results.items():
+         # Check if the value is a list of lists (like [[epoch, value], ...])
+         if isinstance(value, list) and all(isinstance(item, list) and len(item) == 2 for item in value):
+             # Format the list of lists as a single-line JSON string
+             results[key] = json.dumps(value)
+     return results
+
+ # TODO: Revisit this format -- maybe save as an HDF5 file instead?
+ def save_results_to_json(results, filename: str):
+
+     results = prepare_inline_results_json(results)
+     with open(filename, "w") as json_file:
+         json.dump(results, json_file, indent=4)
+     print(f"Training Results saved to {filename}")
+
+ ##############################################################################################################################
+
+ # def save_parameters_to_json(model, trainer, dataloader, filename: str):
+
+ #     parameters = {
+ #         'model': model.get_model_parameters(),
+ #         'optimizer': get_optimizer_parameters(trainer),
+ #         'dataloader': dataloader.get_dataloader_parameters()
+ #     }
+ #     parameters = prepare_for_inline_json(parameters)
+
+ #     with open(os.path.join(filename), "w") as json_file:
+ #         json.dump( parameters, json_file, indent=4 )
+ #     print(f"Training Parameters saved to {filename}")
+
+ # def split_datasets(runIDs,
+ #                    train_ratio: float = 0.7,
+ #                    val_ratio: float = 0.15,
+ #                    test_ratio: float = 0.15,
+ #                    return_test_dataset: bool = True,
+ #                    random_state: int = 42):
+ #     """
+ #     Splits a given dataset into three subsets: training, validation, and testing. The proportions
+ #     of each subset are determined by the provided ratios, ensuring that they add up to 1. The
+ #     function uses a fixed random state for reproducibility.
+
+ #     Parameters:
+ #     - runIDs: The complete dataset that needs to be split.
+ #     - train_ratio: The proportion of the dataset to be used for training.
+ #     - val_ratio: The proportion of the dataset to be used for validation.
+ #     - test_ratio: The proportion of the dataset to be used for testing.
+
+ #     Returns:
+ #     - trainRunIDs: The subset of the dataset used for training.
+ #     - valRunIDs: The subset of the dataset used for validation.
+ #     - testRunIDs: The subset of the dataset used for testing.
+ #     """
+
+ #     # Ensure the ratios add up to 1
+ #     assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
+
+ #     # First, split into train and remaining (30%)
+ #     trainRunIDs, valRunIDs = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
+
+ #     # (Optional) split the remaining into validation and test
+ #     if return_test_dataset:
+ #         valRunIDs, testRunIDs = train_test_split(
+ #             valRunIDs,
+ #             test_size=(test_ratio / (val_ratio + test_ratio)),
+ #             random_state=random_state,
+ #         )
+ #     else:
+ #         testRunIDs = None
+
+ #     return trainRunIDs, valRunIDs, testRunIDs
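For orientation, here is a minimal sketch of how these I/O helpers compose for inference; the config path, voxel spacing, and algorithm string below are hypothetical and should be adapted to your copick project.

    from octopi import io

    # Hypothetical inputs -- adjust to your copick config
    root = io.load_copick_config("config.json")
    run_names = [run.name for run in root.runs]

    # Reproducible train/val/test split over run IDs (seeded via random_state)
    train_ids, val_ids, test_ids = io.split_multiclass_dataset(run_names)

    # Non-shuffled DataLoader over all runs for prediction
    loader, dataset = io.create_predict_dataloader(root, voxel_spacing=10.0, tomo_algorithm="wbp")
    for batch in loader:
        print(batch["image"].shape)  # (B, 1, Z, Y, X) after EnsureChannelFirstd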
octopi/losses.py ADDED
@@ -0,0 +1,86 @@
+ from monai.losses import FocalLoss, TverskyLoss
+ import torch
+
+ class WeightedFocalTverskyLoss(torch.nn.Module):
+     def __init__(
+         self, gamma=1.0, alpha=0.7, beta=0.3,
+         weight_tversky=0.5, weight_focal=0.5,
+         smooth=1e-5, **kwargs ):
+         """
+         Weighted combination of Focal and Tversky loss.
+
+         Args:
+             gamma (float): Focus parameter for Focal Loss.
+             alpha (float): Weight for false positives in Tversky Loss.
+             beta (float): Weight for false negatives in Tversky Loss.
+             weight_tversky (float): Weight of Tversky loss in the combination.
+             weight_focal (float): Weight of Focal loss in the combination.
+             smooth (float): Smoothing factor to avoid division by zero.
+         """
+         super().__init__()
+         self.tversky_loss = TverskyLoss(
+             alpha=alpha, beta=beta, include_background=True,
+             to_onehot_y=True, softmax=True,
+             smooth_nr=smooth, smooth_dr=smooth, **kwargs
+         )
+         self.focal_loss = FocalLoss(
+             include_background=True, to_onehot_y=True,
+             use_softmax=True, gamma=gamma
+         )
+         self.alpha = alpha
+         self.beta = beta
+         self.gamma = gamma
+         self.weight_tversky = weight_tversky
+         self.weight_focal = weight_focal
+
+     def forward(self, y_pred, y_true):
+         """
+         Compute the combined loss.
+
+         Args:
+             y_pred (Tensor): Predicted logits (B, C, ...); softmax is applied internally.
+             y_true (Tensor): Ground truth labels (B, 1, ...) as integer class indices
+                 (converted to one-hot internally via to_onehot_y).
+
+         Returns:
+             Tensor: Weighted combination of Tversky and Focal losses.
+         """
+         tversky = self.tversky_loss(y_pred, y_true)
+         focal = self.focal_loss(y_pred, y_true)
+         return self.weight_tversky * tversky + self.weight_focal * focal
+
+ class FocalTverskyLoss(TverskyLoss):
+     def __init__(
+         self,
+         alpha=0.7, beta=0.3, gamma=1.0, smooth=1e-5, **kwargs):
+         """
+         Focal Tversky Loss with an additional power term for harder samples.
+
+         From https://arxiv.org/abs/1810.07842
+
+         Args:
+             alpha (float): Weight for false positives.
+             beta (float): Weight for false negatives.
+             gamma (float): Focus parameter (like Focal Loss).
+             smooth (float): Smoothing factor to avoid division by zero.
+         """
+         super().__init__(
+             alpha=alpha, beta=beta,
+             include_background=True,
+             to_onehot_y=True, softmax=True,
+             smooth_nr=smooth, smooth_dr=smooth, **kwargs)
+         self.gamma = gamma
+         self.alpha = alpha
+         self.beta = beta
+
+     def forward(self, y_pred, y_true):
+         """
+         Args:
+             y_pred (Tensor): Predicted logits (B, C, ...); softmax is applied internally.
+             y_true (Tensor): Ground truth labels (B, 1, ...) as integer class indices
+                 (converted to one-hot internally via to_onehot_y).
+
+         Returns:
+             Tensor: Loss value, the Tversky loss raised to the power 1/gamma.
+         """
+         tversky_loss = super().forward(y_pred, y_true)
+         modified_loss = torch.pow(tversky_loss, 1 / self.gamma)
+         return modified_loss
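Because both classes set to_onehot_y and softmax, they accept raw logits and integer label volumes directly. A minimal sketch (shapes chosen arbitrarily for illustration):

    import torch
    from octopi.losses import WeightedFocalTverskyLoss, FocalTverskyLoss

    pred = torch.randn(2, 3, 32, 32, 32)             # raw logits for 3 classes
    label = torch.randint(0, 3, (2, 1, 32, 32, 32))  # integer class indices

    combo = WeightedFocalTverskyLoss(gamma=2.0, weight_tversky=0.6, weight_focal=0.4)
    focal_tversky = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=1.5)
    print(combo(pred, label), focal_tversky(pred, label))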
octopi/main.py ADDED
@@ -0,0 +1,101 @@
+ from octopi.processing.importers import cli_dataportal as download_dataportal
+ from octopi.processing.importers import cli_mrcs as import_mrc_volumes
+ from octopi.entry_points.run_create_targets import cli as create_targets
+ from octopi.entry_points.run_train import cli as train_model
+ from octopi.entry_points.run_optuna import cli as model_explore
+ from octopi.entry_points.run_segment_predict import cli as inference
+ from octopi.entry_points.run_localize import cli as localize
+ from octopi.entry_points.run_evaluate import cli as evaluate
+ from octopi.entry_points.run_extract_mb_picks import cli as extract_mb_picks
+ import octopi.entry_points.create_slurm_submission as slurm_submitter
+ import argparse
+ import sys
+
+ def cli_main():
+     """
+     Main CLI entry point for octopi.
+     """
+
+     # Create the main parser
+     parser = argparse.ArgumentParser(
+         description="Octopi 🐙: 🛠️ Tools for Finding Proteins in 🧊 cryo-ET data",
+         formatter_class=argparse.RawDescriptionHelpFormatter
+     )
+
+     # Create subparsers
+     subparsers = parser.add_subparsers(dest="command", help="Available commands")
+     subparsers.required = True  # Make subcommand required
+
+     # Define all subcommands with their help text
+     commands = {
+         "import-mrc-volumes": (import_mrc_volumes, "Import MRC volumes from a directory, with optional downsampling to a coarser voxel size."),
+         "download-dataportal": (download_dataportal, "Download tomograms from the Dataportal, with optional downsampling to a coarser voxel size."),
+         "create-targets": (create_targets, "Generate segmentation targets from coordinates."),
+         "train": (train_model, "Train a single U-Net model."),
+         "model-explore": (model_explore, "Explore model architectures with Optuna / Bayesian optimization."),
+         "inference": (inference, "Perform segmentation inference on tomograms."),
+         "localize": (localize, "Perform localization of particles in tomograms."),
+         "extract-mb-picks": (extract_mb_picks, "Extract membrane-bound picks from tomograms."),
+         "evaluate": (evaluate, "Evaluate the performance of a model."),
+     }
+
+     # Add all subparsers and their help text
+     for cmd_name, (cmd_func, cmd_help) in commands.items():
+         subparsers.add_parser(cmd_name, help=cmd_help)
+
+     # Inspect just the command part to determine which subcommand was chosen
+     if len(sys.argv) > 1 and sys.argv[1] in commands:
+         command = sys.argv[1]
+         cmd_func = commands[command][0]
+
+         # Remove the command name and delegate to the subcommand's own CLI parser
+         sys.argv = [sys.argv[0]] + sys.argv[2:]
+         cmd_func()
+     else:
+         # Just show help if no valid command was given
+         parser.parse_args()
+
+ def cli_slurm_main():
+     """
+     SLURM-specific CLI entry point for octopi.
+     """
+
+     # Create the main parser
+     parser = argparse.ArgumentParser(
+         description="Octopi for SLURM 🖥️: Shell Submission Tools for Running 🐙 on HPC",
+         formatter_class=argparse.RawDescriptionHelpFormatter
+     )
+
+     # Create subparsers
+     subparsers = parser.add_subparsers(dest="command", help="Available commands")
+     subparsers.required = True  # Make subcommand required
+
+     # Define all subcommands with their help text
+     commands = {
+         "import-mrc-volumes": (slurm_submitter.import_mrc_slurm, "Import MRC volumes from a directory."),
+         "download-dataportal": (slurm_submitter.download_dataportal_slurm, "Download tomograms from the Dataportal, with optional downsampling to a coarser voxel size."),
+         # "create-targets": (create_targets, "Generate segmentation targets from coordinates."),
+         "train": (slurm_submitter.train_model_slurm, "Train a single U-Net model."),
+         "model-explore": (slurm_submitter.model_explore_slurm, "Explore model architectures with Optuna / Bayesian optimization."),
+         "inference": (slurm_submitter.inference_slurm, "Perform segmentation inference on tomograms."),
+         "localize": (slurm_submitter.localize_slurm, "Perform localization of particles in tomograms."),
+         # "extract-mb-picks": (extract_mb_picks, "Extract membrane-bound picks from tomograms."),
+         # "evaluate": (evaluate, "Evaluate the performance of a model."),
+     }
+
+     # Add all subparsers and their help text
+     for cmd_name, (cmd_func, cmd_help) in commands.items():
+         subparsers.add_parser(cmd_name, help=cmd_help)
+
+     # Inspect just the command part to determine which subcommand was chosen
+     if len(sys.argv) > 1 and sys.argv[1] in commands:
+         command = sys.argv[1]
+         cmd_func = commands[command][0]
+
+         # Remove the command name and delegate to the subcommand's own CLI parser
+         sys.argv = [sys.argv[0]] + sys.argv[2:]
+         cmd_func()
+     else:
+         # Just show help if no valid command was given
+         parser.parse_args()
+
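Since both entry points dispatch on sys.argv before argparse runs, the CLI can also be exercised programmatically; a minimal sketch (the installed console-script names are defined in entry_points.txt and are not shown here, so "octopi" below is an assumption):

    import sys
    from octopi.main import cli_main

    # Simulate `octopi train --help`; cli_main strips the subcommand and hands
    # the remaining argv to the subcommand's own parser (argparse will exit
    # after printing help, as it would on the command line).
    sys.argv = ["octopi", "train", "--help"]
    cli_main()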