octopi 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
@@ -0,0 +1,95 @@
1
def create_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    num_gpus=1,
    gpu_constraint='h100'):
    """Write a single-job SLURM batch script to ``shell_name``.

    The script requests 18 hours, 4 CPUs with 16G per CPU, loads the
    ``anaconda`` module, activates ``conda_path`` and finally runs ``command``.

    Args:
        job_name (str): Value for ``#SBATCH --job-name``.
        output_file (str): Value for ``#SBATCH --output``.
        shell_name (str): Path of the script file to (over)write.
        conda_path (str): Environment passed to ``conda activate``.
        command (str): Shell command(s) placed at the end of the script.
        num_gpus (int): Number of GPUs to request; a value that is not > 0
            selects the ``cpu`` partition instead.
        gpu_constraint (str): GPU type used in ``--gpus=<type>:<count>``.
    """
    # Partition / GPU directives depend on whether GPUs were requested.
    if num_gpus > 0:
        resource_lines = [
            '#SBATCH --partition=gpu',
            f'#SBATCH --gpus={gpu_constraint}:{num_gpus}',
        ]
    else:
        resource_lines = ['#SBATCH --partition=cpu']

    script_lines = [
        '#!/bin/bash',
        '',
        *resource_lines,
        '#SBATCH --time=18:00:00',
        '#SBATCH --cpus-per-task=4',
        '#SBATCH --mem-per-cpu=16G',
        f'#SBATCH --job-name={job_name}',
        f'#SBATCH --output={output_file}',
        '',
        'ml anaconda',
        f'conda activate {conda_path}',
        f'{command}',
    ]
    # Trailing newline so the final command line is terminated.
    script_text = '\n'.join(script_lines) + '\n'

    with open(shell_name, "w") as handle:
        handle.write(script_text)

    print(f"\nShell script has been created successfully as {shell_name}\n")
34
+
35
def create_shellsubmit_array(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    job_array=None):
    """Write a SLURM job-array batch script to ``shell_name``.

    The script requests 18 hours, 4 CPUs with 16G per CPU, declares
    ``#SBATCH --array=<first>-<last>``, loads the ``anaconda`` module,
    activates ``conda_path`` and runs ``command``.

    Args:
        job_name (str): Value for ``#SBATCH --job-name``.
        output_file (str): Value for ``#SBATCH --output``.
        shell_name (str): Path of the script file to (over)write.
        conda_path (str): Environment passed to ``conda activate``.
        command (str): Shell command(s) placed at the end of the script.
        job_array (sequence): ``job_array[0]`` and ``job_array[1]`` are the first
            and last array task indices. Required in practice.

    Raises:
        ValueError: If ``job_array`` is not provided.
    """
    # The former default ``[min, max]`` referred to the builtin *functions*
    # min/max, which silently produced an invalid
    # ``--array=<built-in function min>-<built-in function max>`` directive.
    # Fail loudly (before touching the file) instead.
    if job_array is None:
        raise ValueError(
            "job_array must be provided as a (first, last) pair of array task indices."
        )
    first_task, last_task = job_array[0], job_array[1]

    shell_script_content = f"""#!/bin/bash

#SBATCH --time=18:00:00
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=16G
#SBATCH --job-name={job_name}
#SBATCH --output={output_file}
#SBATCH --array={first_task}-{last_task}

ml anaconda
conda activate {conda_path}
{command}
"""

    # Save to file
    with open(shell_name, "w") as file:
        file.write(shell_script_content)

    print(f"\nShell script has been created successfully as {shell_name}\n")
62
+
63
def create_multiconfig_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    base_inputs,
    config_inputs,
    command):
    """Write a CPU-partition SLURM script for a multi-config run to ``shell_name``.

    The script requests 24 hours and 4 CPUs with 12G per CPU on the ``cpu``
    partition. It then loads ``anaconda`` and emits ``conda_path``,
    ``base_inputs``, ``config_inputs`` and ``command`` verbatim, separated by
    blank lines.

    Args:
        job_name (str): Value for ``#SBATCH --job-name``.
        output_file (str): Value for ``#SBATCH --output``.
        shell_name (str): Path of the script file to (over)write.
        conda_path (str): Inserted as-is on its own line.
        base_inputs (str): Shell text inserted as-is.
        config_inputs (str): Shell text inserted as-is.
        command (str): Shell command(s) placed at the end of the script.
    """
    # NOTE(review): unlike create_shellsubmit / create_shellsubmit_array, the
    # value is NOT prefixed with ``conda activate`` here. Callers presumably pass
    # the full activation line — confirm.
    body_lines = [
        '#! /bin/bash',
        '',
        f'#SBATCH --job-name={job_name}',
        '#SBATCH --time=24:00:00',
        '#SBATCH --cpus-per-task=4',
        '#SBATCH --mem-per-cpu=12G',
        '#SBATCH --partition=cpu',
        f'#SBATCH --output={output_file}',
        '',
        'ml anaconda',
        f'{conda_path}',
        '',
        f'{base_inputs}',
        '',
        f'{config_inputs}',
        '',
        f'{command}',
    ]
    script_text = '\n'.join(body_lines) + '\n'

    with open(shell_name, "w") as handle:
        handle.write(script_text)

    print(f"\nShell script has been created successfully as {shell_name}\n")
@@ -0,0 +1,290 @@
1
+ from ipywidgets import interact, IntSlider, fixed
2
+ from copick_utils.io import readers
3
+ import matplotlib.colors as mcolors
4
+ from typing import Optional, List
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import copick
8
+
9
+ # Define the interactive function
10
def interact_3d_seg(vol, seg):
    """
    Open an ipywidgets slider that browses ``seg`` drawn over ``vol``.

    The slider walks the first axis of ``vol`` and starts at the middle slice.
    Every slider position re-renders through ``show_tomo_segmentation``.

    Args:
        vol (numpy.ndarray): Tomogram volume; slices are taken along axis 0.
        seg (numpy.ndarray): Segmentation volume, indexed with the same slice
            index as ``vol``.
    """
    last_index = vol.shape[0] - 1
    start_index = int(last_index // 2)

    slice_slider = IntSlider(min=0, max=last_index, step=1, value=start_index)
    interact(
        show_tomo_segmentation,
        tomo=fixed(vol),
        seg=fixed(seg),
        vol_slice=slice_slider,
    )
29
+
30
+ # Define the plotting function
31
def show_tomo_segmentation(tomo, seg, vol_slice):
    """
    Plot one slice as three panels: tomogram, segmentation, and an overlay.

    Args:
        tomo (numpy.ndarray): Tomogram volume; ``tomo[vol_slice]`` is displayed.
        seg (numpy.ndarray): Segmentation volume; ``seg[vol_slice]`` is displayed.
        vol_slice (int): Index along the first axis to display.
    """
    plt.figure(figsize=(20, 10))

    # Each panel: (title, [(volume, imshow kwargs), ...]) drawn in order.
    # The overlay draws the segmentation at 50% opacity on top of the tomogram.
    panel_specs = (
        ('Tomogram', [(tomo, {'cmap': 'gray'})]),
        ('Painted Segmentation from Picks', [(seg, {'cmap': 'viridis'})]),
        ('Overlay', [(tomo, {'cmap': 'gray'}),
                     (seg, {'cmap': 'viridis', 'alpha': 0.5})]),
    )

    for panel_number, (panel_title, layers) in enumerate(panel_specs, start=1):
        plt.subplot(1, 3, panel_number)
        plt.title(panel_title)
        for volume, imshow_kwargs in layers:
            plt.imshow(volume[vol_slice], **imshow_kwargs)
        plt.axis('off')

    plt.tight_layout()
    plt.show()
63
+
64
def show_labeled_tomo_segmentation(tomo, seg, seg_labels, unique_values, vol_slice):
    """
    Show a tomogram slice next to its labeled segmentation with a named colorbar.

    Args:
        tomo (numpy.ndarray): Tomogram volume; ``tomo[vol_slice]`` is displayed.
        seg (numpy.ndarray): Label volume; ``seg[vol_slice]`` is displayed.
        seg_labels (dict): Mapping from numeric label value to display name.
        unique_values (Iterable): Label values to show. Entries of ``seg_labels``
            whose key is not ``in`` this collection are dropped.
        vol_slice (int): Index along the first axis to display.

    Raises:
        ValueError: If no key of ``seg_labels`` is present in ``unique_values``.
    """
    # Keep only labels that occur, in ascending order: BoundaryNorm needs
    # increasing edges, and the colour list / colorbar ticks must line up.
    seg_labels_filtered = {
        k: seg_labels[k] for k in sorted(seg_labels) if k in unique_values
    }
    if not seg_labels_filtered:
        raise ValueError(
            "None of the segmentation labels are present in unique_values; nothing to plot."
        )
    label_values = list(seg_labels_filtered)
    num_classes = len(label_values)

    # One discrete colour per label.
    colors = plt.cm.tab20b(np.linspace(0, 1, num_classes))
    cmap = mcolors.ListedColormap(colors)

    # Bin edges halfway between consecutive label values, so label i always gets
    # colour i. Without a norm, imshow rescales to the slice's own min/max, and
    # colours drift away from the colorbar labels whenever a label is absent.
    edges = [label_values[0] - 0.5]
    edges += [(lo + hi) / 2 for lo, hi in zip(label_values, label_values[1:])]
    edges.append(label_values[-1] + 0.5)
    norm = mcolors.BoundaryNorm(edges, cmap.N)

    plt.figure(figsize=(20, 10))

    # Tomogram plot
    plt.subplot(1, 2, 1)
    plt.title('Tomogram')
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    # Prediction segmentation plot
    plt.subplot(1, 2, 2)
    plt.title('Prediction Segmentation')
    im = plt.imshow(seg[vol_slice], cmap=cmap, norm=norm)
    plt.axis('off')

    # Colorbar with one tick per label, showing the label's name.
    cbar = plt.colorbar(im, ticks=label_values)
    cbar.ax.set_yticklabels([seg_labels_filtered[v] for v in label_values])

    plt.tight_layout()
    plt.show()
100
+
101
def interact_points(
    tomo, config, run_id, user_id='octopi',
    session_id=None, pt_size=15,
    slice_proximity_threshold=3
    ):
    """
    Browse a Copick run's particle picks over a tomogram with an ipywidgets slider.

    Args:
        tomo (numpy.ndarray): Tomogram volume; slices are taken along axis 0.
        config (str): Path to the Copick configuration file.
        run_id (str): Name of the Copick run whose picks are displayed.
        user_id (str): Pick user ID to read.
        session_id (str, optional): Pick session ID to read.
        pt_size (int): Marker size for the scatter points.
        slice_proximity_threshold (int): Forwarded to ``show_tomo_points``, which
            draws picks whose first coordinate lies within this many slices of
            the displayed one.
    """
    copick_root = copick.from_file(config)
    copick_run = copick_root.get_run(run_id)

    # (name, label, radius) for every pickable object flagged as a particle.
    particle_objects = [
        (entry.name, entry.label, entry.radius)
        for entry in copick_root.pickable_objects
        if entry.is_particle
    ]

    last_index = tomo.shape[0] - 1
    start_index = int(last_index // 2)

    interact(
        show_tomo_points,
        tomo=fixed(tomo),
        run=fixed(copick_run),
        objects=fixed(particle_objects),
        user_id=fixed(user_id),
        session_id=fixed(session_id),
        slice_proximity_threshold=fixed(slice_proximity_threshold),
        pt_size=fixed(pt_size),
        vol_slice=IntSlider(min=0, max=last_index, step=1, value=start_index),
    )
138
+
139
def show_tomo_points(
    tomo, run, objects, user_id,
    vol_slice, session_id=None,
    slice_proximity_threshold=3, pt_size=15
    ):
    """
    Show picked coordinates near one tomogram slice.

    Args:
        tomo (numpy.ndarray): Tomogram volume; ``tomo[vol_slice]`` is displayed.
        run (copick.Run): The Copick run to read picks from.
        objects (list): ``(name, label, radius)`` entries; only ``name`` is used.
        user_id (str): Pick user ID to read.
        vol_slice (int): Slice index along the first axis.
        session_id (str, optional): Pick session ID to read.
        slice_proximity_threshold (int): Picks whose column-0 coordinate is within
            this many slices of ``vol_slice`` are drawn.
        pt_size (int): Marker size for the scatter points.
    """
    plt.figure(figsize=(20, 10))

    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    for name, _, _ in objects:
        # Best effort: objects whose picks cannot be read for this user/session
        # are skipped. Catching Exception (rather than a bare ``except``) keeps
        # KeyboardInterrupt / SystemExit working.
        try:
            coordinates = readers.coordinates(run, name=name, user_id=user_id, session_id=session_id)
            # Column 0 is compared against the slice index; columns 2 and 1 are
            # plotted as x and y.
            close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
            plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=pt_size)
        except Exception:
            continue

    plt.show()
172
+
173
def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
                        session_id1=None, session_id2=None, slice_proximity_threshold=3,
                        pt_size=15):
    """
    Show two users' picks side by side on the same tomogram slice.

    Args:
        tomo (numpy.ndarray): Tomogram volume; ``tomo[vol_slice]`` is shown in both panels.
        run (copick.Run): The Copick run to read picks from.
        objects (list): ``(name, label, radius)`` entries; only ``name`` is used.
        vol_slice (int): Slice index along the first axis.
        user_id1 (str): Pick user ID for the left panel.
        user_id2 (str): Pick user ID for the right panel.
        session_id1 (str, optional): Pick session ID for the left panel.
        session_id2 (str, optional): Pick session ID for the right panel.
        slice_proximity_threshold (int): Picks whose column-0 coordinate is within
            this many slices of ``vol_slice`` are drawn.
        pt_size (int): Marker size for the scatter points (default 15).
    """
    plt.figure(figsize=(20, 10))

    panels = ((user_id1, session_id1), (user_id2, session_id2))
    for panel_number, (user_id, session_id) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_number)
        plt.imshow(tomo[vol_slice], cmap='gray')
        plt.title(f'{user_id} Picks')

        for name, _, _ in objects:
            # Best effort: skip objects whose picks cannot be read. Catching
            # Exception (not a bare except) keeps Ctrl-C / SystemExit working.
            try:
                coordinates = readers.coordinates(run, name=name, user_id=user_id, session_id=session_id)
                close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
                plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=pt_size)
            except Exception:
                continue

        # Hide axes on both panels, matching show_tomo_points.
        plt.axis('off')

    plt.show()
203
+
204
def plot_training_results(
    results,
    class_names: Optional[List[str]] = None,
    save_plot: Optional[str] = None,
    fig=None, axs=None):
    """
    Plot training results: loss, per-class recall, precision, and F1 score.

    Args:
        results (dict): Training metrics. Reads ``'loss'``, ``'val_loss'``,
            ``'avg_recall'`` and per-class ``'recall_class<i>'``,
            ``'precision_class<i>'``, ``'f1_class<i>'`` (``i`` starting at 1).
            Each value is a sequence of ``(epoch, value)`` pairs.
        class_names (list, optional): One name per class. Ignored, and replaced
            by ``"Class <i>"``, when missing or when its length does not match
            the number of ``recall_class*`` keys.
        save_plot (str, optional): Path to save the figure to; nothing is saved
            when ``None`` or empty.
        fig (matplotlib.figure.Figure, optional): Existing figure to draw on.
        axs (numpy.ndarray, optional): 2x2 array of axes belonging to ``fig``.
            If ``fig`` is given without ``axs``, the figure is cleared and a new
            2x2 grid is created on it.

    Returns:
        tuple: ``(fig, axs)`` so the caller can redraw into them later.
    """
    # Create or reuse a 2x2 subplot layout.
    if fig is None:
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    elif axs is None:
        fig.clf()
        axs = fig.subplots(2, 2)
    else:
        # Clear previous plots
        for ax in axs.flatten():
            ax.clear()

    fig.suptitle("Metrics Over Epochs", fontsize=16)

    # Loss is logged every epoch.
    epochs_loss = [epoch for epoch, _ in results['loss']]
    loss = [value for _, value in results['loss']]
    val_epochs_loss = [epoch for epoch, _ in results['val_loss']]
    val_loss = [value for _, value in results['val_loss']]

    # Training / validation loss in the top-left.
    axs[0, 0].plot(epochs_loss, loss, label="Training Loss")
    axs[0, 0].plot(val_epochs_loss, val_loss, label='Validation Loss')
    axs[0, 0].set_xlabel("Epochs")
    axs[0, 0].set_ylabel("Loss")
    axs[0, 0].set_title("Training Loss")
    axs[0, 0].legend()
    axs[0, 0].tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Per-class metrics are logged every `val_interval` epochs.
    epochs_metrics = [epoch for epoch, _ in results['avg_recall']]

    # Determine the number of classes and their names.
    num_classes = sum(1 for key in results if key.startswith('recall_class'))
    if class_names is None or len(class_names) != num_classes:
        class_names = [f"Class {i+1}" for i in range(num_classes)]

    # (axes, results-key prefix, y label, title, draw legend?)
    # Only the precision panel carries a legend, as before.
    metric_panels = (
        (axs[0, 1], 'recall_class', "Recall", "Recall per Class", False),
        (axs[1, 0], 'precision_class', "Precision", "Precision per Class", True),
        (axs[1, 1], 'f1_class', "F1 Score", "F1 Score per Class", False),
    )
    for ax, key_prefix, ylabel, title, show_legend in metric_panels:
        for class_idx, class_name in enumerate(class_names):
            values = [value for _, value in results[f'{key_prefix}{class_idx + 1}']]
            ax.plot(epochs_metrics, values, label=f"{class_name}")
        ax.set_xlabel("Epochs")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if show_legend:
            ax.legend()
        ax.tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Lay out *this* figure (not whichever figure pyplot considers current),
    # leaving space for the main title.
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # savefig(None) raises, so only save when a path was supplied.
    if save_plot:
        fig.savefig(save_plot)
    fig.canvas.draw()

    return fig, axs
octopi/workflows.py ADDED
@@ -0,0 +1,262 @@
1
+ from octopi.extract.localize import process_localization
2
+ import octopi.processing.evaluate as octopi_evaluate
3
+ from monai.metrics import ConfusionMatrixMetric
4
+ from octopi.models import common as builder
5
+ from octopi.pytorch import segmentation
6
+ from octopi.pytorch import trainer
7
+ from octopi.utils import io
8
+ import multiprocess as mp
9
+ import copick, torch, os
10
+ from tqdm import tqdm
11
+
12
def train(data_generator, loss_function, num_crops=16,
          model_config=None, model_weights=None, lr0=1e-3,
          model_save_path='results', best_metric='fBeta2',
          num_epochs=1000, use_ema=True, val_interval=10,
          sw_bs=4, overlap=0.5):
    """
    Train a segmentation model, then save its parameters and metrics.

    Args:
        data_generator: Data source passed to ``ModelTrainer.train`` and
            ``io.save_parameters_to_yaml``. Its ``Nclasses`` attribute sizes
            the default model.
        loss_function: Loss passed to ``trainer.ModelTrainer``.
        num_crops (int): Passed to ``ModelTrainer.train`` as ``my_num_samples``
            (presumably the number of crops sampled per volume — confirm in trainer).
        model_config (dict, optional): Model configuration with at least
            ``'architecture'`` and ``'dim_in'``. A full config dict with a
            ``'model'`` entry is also accepted. ``None`` selects a default UNet.
        model_weights (str, optional): Path to a state dict loaded before training.
        lr0 (float): Initial learning rate for AdamW.
        model_save_path (str): Directory for training outputs;
            ``model_config.yaml`` and ``results.csv`` are written there.
        best_metric (str): Passed to ``ModelTrainer.train`` as ``best_metric``.
        num_epochs (int): Maximum number of training epochs.
        use_ema (bool): Passed to ``ModelTrainer`` as ``use_ema``.
        val_interval (int): Number of epochs between validations.
        sw_bs (int): Sliding-window batch size used for validation.
        overlap (float): Sliding-window overlap used for validation.

    Returns:
        The object returned by ``ModelTrainer.train`` (also written to ``results.csv``).
    """
    # If no model configuration is provided, fall back to a default UNet.
    if model_config is None:
        model_config = {
            'architecture': 'Unet',
            'num_classes': data_generator.Nclasses,
            'dim_in': 80,
            'strides': [2, 2, 1],
            'channels': [48, 64, 80, 80],
            'dropout': 0.0, 'num_res_units': 1,
        }
        print('No Model Configuration Provided, Using Default Configuration')
        print(model_config)
    # Accept a full config dict and extract its model section.
    elif isinstance(model_config, dict) and 'model' in model_config:
        model_config = model_config['model']

    # Per-class recall / precision / F1 (background excluded) for validation.
    metrics_function = ConfusionMatrixMetric(include_background=False, metric_name=["recall", 'precision', 'f1 score'], reduction="none")

    # Build the model
    model_builder = builder.get_model(model_config['architecture'])
    model = model_builder.build_model(model_config)

    # Load weights if provided, then move the model to the device either way.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if model_weights:
        print(f'Loading Model Weights from: {model_weights}\n')
        state_dict = torch.load(model_weights, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    model.to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr0, weight_decay=1e-4
    )

    # Create the trainer and configure sliding-window validation.
    model_trainer = trainer.ModelTrainer(
        model, device, loss_function, metrics_function,
        optimizer, use_ema=use_ema
    )
    model_trainer.sw_bs = sw_bs
    model_trainer.overlap = overlap

    # Train the model
    print(f'🔃 Starting Training...\nSaving Training Results to: {model_save_path}/\n')
    results = model_trainer.train(
        data_generator, model_save_path, max_epochs=num_epochs,
        crop_size=model_config['dim_in'], my_num_samples=num_crops,
        val_interval=val_interval, best_metric=best_metric, verbose=True
    )
    print('✅ Training Complete!')

    # Save parameters and results
    print(f'💾 Saving Training Parameters and Results to: {model_save_path}/\n')
    parameters_save_name = os.path.join(model_save_path, "model_config.yaml")
    io.save_parameters_to_yaml(model_builder, model_trainer, data_generator, parameters_save_name)

    results_save_name = os.path.join(model_save_path, "results.csv")
    io.save_results_to_csv(results, results_save_name)

    return results
99
+
100
def segment(config, tomo_algorithm, voxel_size, model_weights, model_config,
            seg_info=('predict', 'octopi', '1'), use_tta=False, run_ids=None,
            num_tomos_per_batch=15):
    """
    Segment a dataset using a trained model or an ensemble of models.

    Args:
        config (str): Path to the Copick config file.
        tomo_algorithm (str): Tomogram reconstruction algorithm to read.
        voxel_size (float): Voxel spacing of the tomograms.
        model_weights (str, list): Path to the model weights, or a list of paths.
        model_config (str, list): Model configuration, or a list of configurations.
        seg_info (sequence): ``(segmentation_name, user_id, session_id)`` used to
            write the predictions. A tuple default avoids a shared mutable default;
            lists are still accepted.
        use_tta (bool): Whether to apply test-time augmentation.
        run_ids (list, optional): Run IDs to segment; passed through to
            ``batch_predict`` as ``runIDs``.
        num_tomos_per_batch (int): Passed to ``batch_predict`` (default 15, the
            previously hard-coded value).
    """
    # Initialize the predictor
    predict = segmentation.Predictor(
        config, model_config, model_weights,
        apply_tta=use_tta
    )

    # Run batch prediction
    predict.batch_predict(
        runIDs=run_ids,
        num_tomos_per_batch=num_tomos_per_batch,
        tomo_algorithm=tomo_algorithm,
        voxel_spacing=voxel_size,
        segmentation_name=seg_info[0],
        segmentation_user_id=seg_info[1],
        segmentation_session_id=seg_info[2]
    )

    print('✅ Segmentation Complete!')
134
+
135
def localize(config, voxel_size, seg_info, pick_user_id, pick_session_id, n_procs=16,
             method='watershed', filter_size=10, radius_min_scale=0.4, radius_max_scale=1.0,
             run_ids=None, pick_objects=None):
    """
    Extract 3D coordinates from segmentation maps, in parallel across runs.

    Args:
        config (str): Path to the Copick config file.
        voxel_size (float): Voxel spacing of the segmentations.
        seg_info (sequence): ``(segmentation_name, user_id, session_id)`` of the
            segmentation to read. Also used to look up its config via ``io.get_config``.
        pick_user_id (str): Pick user ID passed to ``process_localization``.
        pick_session_id (str): Pick session ID passed to ``process_localization``.
        n_procs (int): Maximum number of worker processes (capped at the number of runs).
        method (str): Localization method passed to ``process_localization``.
        filter_size (int): Filter size passed to ``process_localization``.
        radius_min_scale (float): Minimum radius scale passed to ``process_localization``.
        radius_max_scale (float): Maximum radius scale passed to ``process_localization``.
        run_ids (list, optional): Run IDs to process; ``None`` processes every run.
        pick_objects (list, optional): Object names to localize. ``None`` keeps every
            particle object present in the segmentation's label map.

    Raises:
        ValueError: If a particle object has a missing or non-numeric label/radius,
            if no particle object appears in the segmentation label map, or if
            ``pick_objects`` filters out every object.
    """
    # Load the Copick config
    root = copick.from_file(config)

    # [name, label, radius] rows for every particle object. Validate while
    # converting: previously the check ran after float()/int(), so a missing
    # radius surfaced as an opaque TypeError.
    objects = []
    for obj in root.pickable_objects:
        if not obj.is_particle:
            continue
        try:
            objects.append([obj.name, int(obj.label), float(obj.radius)])
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Invalid object format: {(obj.name, obj.label, obj.radius)}. Expected a tuple with (name, label, radius)."
            ) from err

    # Load the model output configuration
    seg_config = io.get_config(config, seg_info[0], 'segment', seg_info[1], seg_info[2])

    # Use the segmentation's label for each object and drop objects it does not contain.
    label_map = seg_config.get('labels', {})
    objects = [[name, int(label_map[name]), radius]
               for name, _, radius in objects if name in label_map]
    if not objects:
        raise ValueError(
            f"No particle objects match the segmentation labels {list(label_map)}; nothing to localize."
        )

    # Filter objects based on the provided list
    if pick_objects is not None:
        available_names = [obj[0] for obj in objects]
        objects = [obj for obj in objects if obj[0] in pick_objects]
        if len(objects) == 0:
            raise ValueError(f"No valid objects found for localization after filtering. Mismatched names: {pick_objects} and {available_names}")

    # Get all run IDs
    if run_ids is None:
        run_ids = [run.name for run in root.runs]
    n_run_ids = len(run_ids)

    # Exit if no runs are available
    if n_run_ids == 0:
        print(f"No runs available for localization with the specified voxel size - {voxel_size}.\nExiting...")
        return

    # Never start more workers than there are runs to process.
    n_workers = min(n_procs, n_run_ids)

    # Main parallelization loop. Passing a lambda to the pool relies on
    # `multiprocess` (dill-based pickling) rather than stdlib multiprocessing.
    print(f"Using {n_workers} processes to parallelize across {n_run_ids} run IDs.")
    with mp.Pool(processes=n_workers) as pool:
        with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
            worker_func = lambda run_id: process_localization(
                root.get_run(run_id),
                objects,
                seg_info,
                method,
                voxel_size,
                filter_size,
                radius_min_scale,
                radius_max_scale,
                pick_session_id,
                pick_user_id
            )

            for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
                pbar.update(1)

    print('✅ Localization Complete!')
217
+
218
+
219
def evaluate(config,
             gt_user_id, gt_session_id,
             pred_user_id, pred_session_id,
             run_ids=None, distance_threshold=0.5, save_path=None,
             object_names=None):
    """
    Evaluate predicted picks against ground-truth picks.

    Args:
        config (str): Path to the Copick config file.
        gt_user_id (str): User ID of the ground-truth picks.
        gt_session_id (str): Session ID of the ground-truth picks.
        pred_user_id (str): User ID of the predicted picks.
        pred_session_id (str): Session ID of the predicted picks.
        run_ids (list, optional): Run IDs to evaluate; passed as ``runIDs``.
        distance_threshold (float): Passed to ``evaluator.run`` as
            ``distance_threshold_scale`` (presumably relative to particle size —
            confirm in octopi.processing.evaluate).
        save_path (str, optional): Where the evaluator saves its results.
        object_names (list, optional): Object names to evaluate; ``None`` (the
            previous hard-coded value) leaves the choice to the evaluator.
    """
    print('Running Evaluation on the Following Query:')
    print(f'Distance Threshold: {distance_threshold}')
    print(f'GT User ID: {gt_user_id}, GT Session ID: {gt_session_id}')
    print(f'Pred User ID: {pred_user_id}, Pred Session ID: {pred_session_id}')
    print(f'Run IDs: {run_ids}')

    # The evaluator receives the config path directly; loading a Copick root
    # here was unused work. The instance name avoids shadowing builtin eval().
    copick_evaluator = octopi_evaluate.evaluator(
        config,
        gt_user_id,
        gt_session_id,
        pred_user_id,
        pred_session_id,
        object_names=object_names
    )

    copick_evaluator.run(
        distance_threshold_scale=distance_threshold,
        runIDs=run_ids, save_path=save_path
    )