octopi-1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic.

Files changed (59)
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/entry_points/run_train.py
@@ -0,0 +1,201 @@
+ from octopi.datasets import generators, multi_config_generator
+ from monai.losses import DiceLoss, FocalLoss, TverskyLoss
+ from octopi.models import common as builder
+ from monai.metrics import ConfusionMatrixMetric
+ from octopi.entry_points import common
+ from octopi.pytorch import trainer
+ from octopi import io, utils
+ import torch, os, argparse
+ from typing import List, Optional, Tuple
+ import pprint
+
+ def train_model(
+     copick_config_path: str,
+     target_info: Tuple[str, str, str],
+     tomo_algorithm: str = 'wbp',
+     voxel_size: float = 10,
+     trainRunIDs: List[str] = None,
+     validateRunIDs: List[str] = None,
+     model_config: str = None,
+     model_weights: Optional[str] = None,
+     model_save_path: str = 'results',
+     num_tomo_crops: int = 16,
+     tomo_batch_size: int = 15,
+     lr: float = 1e-3,
+     tversky_alpha: float = 0.5,
+     num_epochs: int = 100,
+     val_interval: int = 5,
+     best_metric: str = 'avg_f1',
+     data_split: str = '0.8'
+     ):
+
+     # Initialize the data generator to manage training and validation datasets
+     print(f'Training with {copick_config_path}\n')
+     if isinstance(copick_config_path, dict):
+         # Multi-config training
+         data_generator = multi_config_generator.MultiConfigTrainLoaderManager(
+             copick_config_path,
+             target_info[0],
+             target_session_id = target_info[2],
+             target_user_id = target_info[1],
+             tomo_algorithm = tomo_algorithm,
+             voxel_size = voxel_size,
+             Nclasses = model_config['num_classes'],
+             tomo_batch_size = tomo_batch_size )
+     else:
+         # Single-config training
+         data_generator = generators.TrainLoaderManager(
+             copick_config_path,
+             target_info[0],
+             target_session_id = target_info[2],
+             target_user_id = target_info[1],
+             tomo_algorithm = tomo_algorithm,
+             voxel_size = voxel_size,
+             Nclasses = model_config['num_classes'],
+             tomo_batch_size = tomo_batch_size )
+
+
+     # Get the data splits
+     ratios = utils.parse_data_split(data_split)
+     data_generator.get_data_splits(
+         trainRunIDs = trainRunIDs,
+         validateRunIDs = validateRunIDs,
+         train_ratio = ratios[0], val_ratio = ratios[1], test_ratio = ratios[2],
+         create_test_dataset = False)
+
+     # Get the reload frequency
+     data_generator.get_reload_frequency(num_epochs)
+
+     # Monai Functions
+     alpha = tversky_alpha
+     beta = 1 - alpha
+     loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True, alpha=alpha, beta=beta)
+     metrics_function = ConfusionMatrixMetric(include_background=False, metric_name=["recall",'precision','f1 score'], reduction="none")
+
+     # Build the Model
+     model_builder = builder.get_model(model_config['architecture'])
+     model = model_builder.build_model(model_config)
+
+     # Load the Model Weights if Provided
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     if model_weights:
+         state_dict = torch.load(model_weights, map_location=device, weights_only=True)
+         model.load_state_dict(state_dict)
+     model.to(device)
+
+     # Optimizer
+     optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=1e-4)
+
+     # Create UNet-Trainer
+     model_trainer = trainer.ModelTrainer(model, device, loss_function, metrics_function, optimizer)
+
+     results = model_trainer.train(
+         data_generator, model_save_path, max_epochs=num_epochs,
+         crop_size=model_config['dim_in'], my_num_samples=num_tomo_crops,
+         val_interval=val_interval, best_metric=best_metric, verbose=True
+     )
+
+     # Save parameters and results
+     parameters_save_name = os.path.join(model_save_path, "model_config.yaml")
+     io.save_parameters_to_yaml(model_builder, model_trainer, data_generator, parameters_save_name)
+
+     # TODO: Write Results to Zarr or Another File Format?
+     results_save_name = os.path.join(model_save_path, "results.json")
+     io.save_results_to_json(results, results_save_name)
+
+ def train_model_parser(parser_description, add_slurm: bool = False):
+     """
+     Parse the arguments for the training model
+     """
+     parser = argparse.ArgumentParser(
+         description=parser_description,
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+     # Input Arguments
+     input_group = parser.add_argument_group("Input Arguments")
+     common.add_config(input_group, single_config=False)
+     input_group.add_argument("--target-info", type=utils.parse_target, default="targets,octopi,1",
+                              help="Target information, e.g., 'name' or 'name,user_id,session_id'. Default is 'targets,octopi,1'.")
+     input_group.add_argument("--tomo-alg", default='wbp', help="Tomogram algorithm used for training")
+     input_group.add_argument("--trainRunIDs", type=utils.parse_list, help="List of training run IDs, e.g., run1,run2,run3")
+     input_group.add_argument("--validateRunIDs", type=utils.parse_list, help="List of validation run IDs, e.g., run4,run5,run6")
+     input_group.add_argument('--data-split', type=str, default='0.8',
+                              help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
+                                   "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")
+
+     fine_tune_group = parser.add_argument_group("Fine-Tuning Arguments")
+     fine_tune_group.add_argument('--model-config', type=str, help="Path to the model configuration file (typically used for fine-tuning)")
+     fine_tune_group.add_argument('--model-weights', type=str, help="Path to the model weights file (typically used for fine-tuning)")
+
+     # Model Arguments
+     model_group = parser.add_argument_group("UNet-Model Arguments")
+     common.add_model_parameters(model_group)
+
+     # Training Arguments
+     train_group = parser.add_argument_group("Training Arguments")
+     common.add_train_parameters(train_group)
+
+     # SLURM Arguments
+     if add_slurm:
+         slurm_group = parser.add_argument_group("SLURM Arguments")
+         common.add_slurm_parameters(slurm_group, 'train', gpus = 1)
+
+     args = parser.parse_args()
+     return args
+
+ # Entry point with argparse
+ def cli():
+     """
+     CLI entry point for training models where results can either be saved to a local directory or a server with MLFlow.
+     """
+
+     # Parse the arguments
+     parser_description = "Train 3D CNN U-Net models"
+     args = train_model_parser(parser_description)
+
+     # Parse the CoPick configuration paths
+     if len(args.config) > 1: copick_configs = utils.parse_copick_configs(args.config)
+     else:                    copick_configs = args.config[0]
+
+     if args.model_config:
+         model_config = utils.load_yaml(args.model_config)
+     else:
+         model_config = get_model_config(args.channels, args.strides, args.res_units, args.Nclass, args.dim_in)
+
+     # Call the training function
+     train_model(
+         copick_config_path=copick_configs,
+         target_info=args.target_info,
+         tomo_algorithm=args.tomo_alg,
+         voxel_size=args.voxel_size,
+         model_config=model_config,
+         model_weights=args.model_weights,
+         model_save_path=args.model_save_path,
+         num_tomo_crops=args.num_tomo_crops,
+         tomo_batch_size=args.tomo_batch_size,
+         lr=args.lr,
+         tversky_alpha=args.tversky_alpha,
+         num_epochs=args.num_epochs,
+         val_interval=args.val_interval,
+         best_metric=args.best_metric,
+         trainRunIDs=args.trainRunIDs,
+         validateRunIDs=args.validateRunIDs,
+         data_split=args.data_split
+     )
+
+ def get_model_config(channels, strides, res_units, Nclass, dim_in):
+     """
+     Create a model configuration dictionary if no model configuration file is provided.
+     """
+     model_config = {
+         'architecture': 'Unet',
+         'channels': channels,
+         'strides': strides,
+         'num_res_units': res_units,
+         'num_classes': Nclass,
+         'dropout': 0.1,
+         'dim_in': dim_in
+     }
+     return model_config
+
+ if __name__ == "__main__":
+     cli()
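For orientation, here is a minimal sketch of driving this entry point from Python rather than through the CLI. Only train_model and get_model_config come from the file above; the config path, target tuple, channel/stride values, and output directory are placeholder assumptions, not values shipped with the package.

# Hypothetical usage sketch -- every path and hyperparameter value below is a placeholder.
from octopi.entry_points.run_train import train_model, get_model_config

# Mirror what get_model_config() builds when no YAML config is supplied (assumed values).
model_config = get_model_config(
    channels=[32, 64, 128, 256],   # assumed U-Net channel progression
    strides=[2, 2, 2],             # assumed downsampling strides
    res_units=2,
    Nclass=3,                      # e.g., background + two particle classes
    dim_in=96,                     # crop size handed to the trainer as crop_size
)

train_model(
    copick_config_path='copick_config.json',   # placeholder CoPick project config
    target_info=('targets', 'octopi', '1'),    # (name, user_id, session_id), the CLI default
    tomo_algorithm='wbp',
    voxel_size=10,
    model_config=model_config,
    model_save_path='results',
    num_epochs=100,
)

These are the same keyword arguments that cli() forwards after parsing, so the programmatic path and the command-line path stay in sync.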
octopi/extract/localize.py
@@ -0,0 +1,254 @@
+ from skimage.morphology import binary_erosion, binary_dilation, ball
+ from scipy.cluster.hierarchy import fcluster, linkage
+ from skimage.segmentation import watershed
+ from typing import List, Optional, Tuple
+ from skimage.measure import regionprops
+ from scipy.spatial import distance
+ from dataclasses import dataclass
+ from octopi import io
+ import scipy.ndimage as ndi
+ from tqdm import tqdm
+ import numpy as np
+ import math
+
+ def processs_localization(run,
+                           objects,
+                           seg_info: Tuple[str, str, str],
+                           method: str = 'com',
+                           voxel_size: float = 10,
+                           filter_size: int = None,
+                           radius_min_scale: float = 0.5,
+                           radius_max_scale: float = 1.0,
+                           pick_session_id: str = '1',
+                           pick_user_id: str = 'monai'):
+
+     # Check if method is valid
+     if method not in ['watershed', 'com']:
+         raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")
+
+     # Get Segmentation
+     seg = io.get_segmentation_array(run,
+                                     voxel_size,
+                                     seg_info[0],
+                                     user_id=seg_info[1],
+                                     session_id=seg_info[2],
+                                     raise_error=False)
+
+     # Preprocess Segmentation
+     # seg = preprocess_segmentation(seg, voxel_size, objects)
+
+     # If No Segmentation is Found, Return
+     if seg is None:
+         return
+
+     # Iterate through all user pickable objects
+     for obj in objects:
+
+         # Extract Particle Radius from Root
+         min_radius = obj[2] * radius_min_scale / voxel_size
+         max_radius = obj[2] * radius_max_scale / voxel_size
+
+         if method == 'watershed':
+             points = extract_particle_centroids_via_watershed(seg, obj[1], filter_size, min_radius, max_radius)
+         elif method == 'com':
+             points = extract_particle_centroids_via_com(seg, obj[1], min_radius, max_radius)
+         points = np.array(points)
+
+         # Save Coordinates if any 3D points are provided
+         if points.size > 2:
+
+             # Remove Picks that are too close to each other
+             # points = remove_repeated_picks(points, min_radius, pixelSize = voxel_size)
+
+             # Swap the coordinates to match the expected format
+             points = points[:,[2,1,0]]
+
+             # Convert the Picks back to Angstrom
+             points *= voxel_size
+
+             # Save Picks
+             try:
+                 picks = run.new_picks(object_name = obj[0], session_id = pick_session_id, user_id=pick_user_id)
+             except:
+                 picks = run.get_picks(object_name = obj[0], session_id = pick_session_id, user_id=pick_user_id)[0]
+
+             # Assign Identity As Orientation
+             orientations = np.zeros([points.shape[0], 4, 4])
+             orientations[:,:3,:3] = np.identity(3)
+             orientations[:,3,3] = 1
+
+             picks.from_numpy( points, orientations )
+         else:
+             print(f"{run.name} didn't have any available picks for {obj[0]}!")
+
+
+ def extract_particle_centroids_via_watershed(
+     segmentation,
+     segmentation_idx,
+     maxima_filter_size,
+     min_particle_radius,
+     max_particle_radius):
+     """
+     Process a specific label in the segmentation, extract centroids, and save them as picks.
+
+     Args:
+         segmentation (np.ndarray): Multilabel segmentation array.
+         segmentation_idx (int): The specific label from the segmentation to process.
+         maxima_filter_size (int): Size of the maximum detection filter.
+         min_particle_radius (float): Minimum particle radius (in voxels) used to derive the size threshold.
+         max_particle_radius (float): Maximum particle radius (in voxels) used to derive the size threshold.
+     """
+
+     if maxima_filter_size is None or maxima_filter_size <= 0:
+         raise ValueError('Enter a non-zero filter size!')
+
+     # Calculate minimum and maximum particle volumes based on the given radii
+     min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
+     max_particle_size = (4 / 3) * np.pi * (max_particle_radius ** 3)
+
+     # Create a binary mask for the specific segmentation label
+     binary_mask = (segmentation == segmentation_idx).astype(int)
+
+     # Skip if the segmentation label is not present
+     if np.sum(binary_mask) == 0:
+         print(f"No segmentation with label {segmentation_idx} found.")
+         return
+
+     # Structuring element for erosion and dilation
+     struct_elem = ball(1)
+     eroded = binary_erosion(binary_mask, struct_elem)
+     dilated = binary_dilation(eroded, struct_elem)
+
+     # Distance transform and local maxima detection
+     distance = ndi.distance_transform_edt(dilated)
+     local_max = (distance == ndi.maximum_filter(distance, footprint=np.ones((maxima_filter_size, maxima_filter_size, maxima_filter_size))))
+
+     # Watershed segmentation
+     markers, _ = ndi.label(local_max)
+     watershed_labels = watershed(-distance, markers, mask=dilated)
+
+     # Extract region properties and filter based on particle size
+     all_centroids = []
+     for region in regionprops(watershed_labels):
+         if min_particle_size <= region.area <= max_particle_size:
+
+             # Option 1: Use all centroids
+             all_centroids.append(region.centroid)
+
+     return all_centroids
+
+ def extract_particle_centroids_via_com(
+     segmentation,
+     segmentation_idx,
+     min_particle_radius,
+     max_particle_radius
+     ):
+     """
+     Process a specific label in the segmentation, extract centroids, and save them as picks.
+
+     Args:
+         segmentation (np.ndarray): Multilabel segmentation array.
+         segmentation_idx (int): The specific label from the segmentation to process.
+         min_particle_radius (float): Minimum particle radius (in voxels) used to derive the size threshold.
+         max_particle_radius (float): Maximum particle radius (in voxels) used to derive the size threshold.
+     """
+
+     # Calculate minimum and maximum particle volumes based on the given radii
+     min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
+     max_particle_size = (4 / 3) * np.pi * (max_particle_radius ** 3)
+
+     # Label connected components for the specific segmentation label
+     label_objs, _ = ndi.label(segmentation == segmentation_idx)
+
+     # Filter candidates based on object size
+     # Get the sizes of all labeled objects
+     object_sizes = np.bincount(label_objs.flat)
+
+     # Filter the objects based on size
+     valid_objects = np.where((object_sizes > min_particle_size) & (object_sizes < max_particle_size))[0]
+
+     # Estimate coordinates from the center of mass of each labeled object
+     octopiCoords = []
+     for object_num in tqdm(valid_objects):
+         com = ndi.center_of_mass(label_objs == object_num)
+         swapped_com = (com[2], com[1], com[0])
+         octopiCoords.append(swapped_com)
+
+     return octopiCoords
+
+ def remove_repeated_picks(coordinates, distanceThreshold, pixelSize = 1):
+
+     # Calculate the distance matrix for the 3D coordinates
+     dist_matrix = distance.cdist(coordinates[:, :3]/pixelSize, coordinates[:, :3]/pixelSize)
+
+     # Create a linkage matrix using the complete linkage method
+     Z = linkage(dist_matrix, method='complete')
+
+     # Form flat clusters with a distance threshold to determine groups
+     clusters = fcluster(Z, t=distanceThreshold, criterion='distance')
+
+     # Initialize an array to store the average of each group
+     unique_coordinates = np.zeros((max(clusters), coordinates.shape[1]))
+
+     # Calculate the mean for each cluster
+     for i in range(1, max(clusters) + 1):
+         unique_coordinates[i-1] = np.mean(coordinates[clusters == i], axis=0)
+
+     return unique_coordinates
+
+ def preprocess_segmentation(segmentation, voxel_size, particle_info):
+     """
+     Remove tiny fragments that aren't real particles.
+
+     Args:
+         segmentation (np.ndarray): The multilabel segmentation array
+         particle_info (list): List of tuples containing (name, segment_id, radius)
+
+     Returns:
+         np.ndarray: Processed segmentation with small fragments removed
+     """
+     import numpy as np
+     from skimage.morphology import remove_small_objects
+
+     processed_seg = segmentation.copy()
+
+     # Map segment IDs to particle types and their minimum sizes
+     scale = 0.3
+     segment_to_info = {}
+     for name, segment_id, radius in particle_info:
+         # # For small particles, use a larger minimum size
+         # if radius < 135:
+         #     scale = 0.65
+         # # Normal threshold for other particles
+         # else:
+         #     scale = 0.4
+         radius = radius / voxel_size
+         min_size = (4/3) * np.pi * ((radius * 0.5) ** 3)
+
+         segment_to_info[segment_id] = {
+             'name': name,
+             'min_size': min_size
+         }
+
+     # Get unique labels
+     unique_labels = np.unique(segmentation)
+     unique_labels = unique_labels[unique_labels > 0]  # Skip background
+
+     # Process each label
+     for label in unique_labels:
+         if label not in segment_to_info:
+             continue
+
+         # Create binary mask for this label
+         mask = segmentation == label
+
+         # Get minimum size for this particle type
+         min_size = segment_to_info[label]['min_size']
+
+         # Remove small objects
+         cleaned_mask = remove_small_objects(mask, min_size=min_size * scale)
+
+         # Update segmentation
+         processed_seg[mask & ~cleaned_mask] = 0
+
+     return processed_seg
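Below, a minimal sketch of looping the localization routine above over a CoPick project. The config path, segmentation name, and user/session IDs are placeholders, and the copick calls (from_file, root.runs, root.pickable_objects with name/label/radius/is_particle) are assumptions about that library's API rather than anything defined in this file.

# Hypothetical usage sketch -- the copick API usage and all identifiers below are assumptions.
import copick
from octopi.extract.localize import processs_localization

root = copick.from_file('copick_config.json')    # placeholder project config
objects = [
    (obj.name, obj.label, obj.radius)             # (name, segmentation label, radius in Angstroms)
    for obj in root.pickable_objects if obj.is_particle
]

for run in root.runs:
    processs_localization(
        run, objects,
        seg_info=('predict', 'octopi', '1'),      # placeholder (segmentation name, user_id, session_id)
        method='com',                             # 'watershed' also works but requires filter_size
        voxel_size=10,
        pick_user_id='octopi',
        pick_session_id='1',
    )

The radius in each tuple is divided by voxel_size inside processs_localization, so it should be given in Angstroms, matching how the resulting picks are scaled back to Angstroms before saving.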