octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
def create_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    num_gpus = 1,
    gpu_constraint = 'h100'):
    """
    Write a single-job Slurm submission script to disk.

    Args:
        job_name: Value written after ``#SBATCH --job-name=``.
        output_file: Value written after ``#SBATCH --output=``.
        shell_name: Path of the shell script to create (overwritten if present).
        conda_path: Argument given to ``conda activate`` inside the script.
        command: Command line placed at the end of the script.
        num_gpus: When greater than zero the job targets the ``gpu`` partition
            and requests this many GPUs; otherwise the ``cpu`` partition is used.
        gpu_constraint: GPU type written in ``--gpus=<type>:<count>``.
    """

    # Resource request: GPU partition with a typed GPU count, or plain CPU partition
    if num_gpus > 0:
        resource_lines = [
            '#SBATCH --partition=gpu',
            f'#SBATCH --gpus={gpu_constraint}:{num_gpus}',
        ]
    else:
        resource_lines = ['#SBATCH --partition=cpu']

    # One entry per script line; the trailing '' yields the final newline
    script_lines = [
        '#!/bin/bash',
        '',
        *resource_lines,
        '#SBATCH --time=18:00:00',
        '#SBATCH --cpus-per-task=4',
        '#SBATCH --mem-per-cpu=16G',
        f'#SBATCH --job-name={job_name}',
        f'#SBATCH --output={output_file}',
        '',
        'ml anaconda',
        f'conda activate {conda_path}',
        f'{command}',
        '',
    ]

    # Save to file
    with open(shell_name, "w") as handle:
        handle.write("\n".join(script_lines))

    print(f"\nShell script has been created successfully as {shell_name}\n")
|
|
34
|
+
|
|
35
|
+
def create_shellsubmit_array(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    job_array = None):
    """
    Write a Slurm job-array submission script to disk.

    Args:
        job_name: Value written after ``#SBATCH --job-name=``.
        output_file: Value written after ``#SBATCH --output=``.
        shell_name: Path of the shell script to create (overwritten if present).
        conda_path: Argument given to ``conda activate`` inside the script.
        command: Command line placed at the end of the script.
        job_array: Indexable pair ``(first, last)`` written as
            ``#SBATCH --array=<first>-<last>``. Required.

    Raises:
        ValueError: If ``job_array`` is not supplied.
    """

    # The previous default was the list ``[min, max]`` (the builtin functions),
    # which rendered as "--array=<built-in function min>-<built-in function max>".
    # There is no meaningful default range, so fail loudly instead.
    if job_array is None:
        raise ValueError(
            "job_array must be provided as a (first_index, last_index) pair, e.g. [0, 9]."
        )

    first_index, last_index = job_array[0], job_array[1]

    shell_script_content = f"""#!/bin/bash

#SBATCH --time=18:00:00
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=16G
#SBATCH --job-name={job_name}
#SBATCH --output={output_file}
#SBATCH --array={first_index}-{last_index}

ml anaconda
conda activate {conda_path}
{command}
"""

    # Save to file
    with open(shell_name, "w") as file:
        file.write(shell_script_content)

    print(f"\nShell script has been created successfully as {shell_name}\n")
|
|
62
|
+
|
|
63
|
+
def create_multiconfig_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    base_inputs,
    config_inputs,
    command):
    """
    Write a CPU-partition Slurm script built from several text fragments.

    Args:
        job_name: Value written after ``#SBATCH --job-name=``.
        output_file: Value written after ``#SBATCH --output=``.
        shell_name: Path of the shell script to create (overwritten if present).
        conda_path: Text inserted verbatim on the line after ``ml anaconda``.
            NOTE(review): unlike the other script writers in this module, no
            ``conda activate`` prefix is added here - confirm callers pass the
            full activation command.
        base_inputs: Text block inserted verbatim after the environment setup.
        config_inputs: Text block inserted verbatim after ``base_inputs``.
        command: Command line placed at the end of the script.
    """

    header = [
        '#! /bin/bash',
        '',
        f'#SBATCH --job-name={job_name}',
        '#SBATCH --time=24:00:00',
        '#SBATCH --cpus-per-task=4',
        '#SBATCH --mem-per-cpu=12G',
        '#SBATCH --partition=cpu',
        f'#SBATCH --output={output_file}',
        '',
    ]

    # Caller-supplied fragments, each followed by a blank separator line
    body = [
        'ml anaconda',
        f'{conda_path}',
        '',
        f'{base_inputs}',
        '',
        f'{config_inputs}',
        '',
        f'{command}',
        '',  # produces the trailing newline
    ]

    # Save to file
    with open(shell_name, "w") as handle:
        handle.write("\n".join(header + body))

    print(f"\nShell script has been created successfully as {shell_name}\n")
|
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
from ipywidgets import interact, IntSlider, fixed
|
|
2
|
+
from copick_utils.io import readers
|
|
3
|
+
import matplotlib.colors as mcolors
|
|
4
|
+
from typing import Optional, List
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
import copick
|
|
8
|
+
|
|
9
|
+
# Define the interactive function
|
|
10
|
+
def interact_3d_seg(vol, seg):
    """
    Launch a slider widget for browsing a tomogram with its segmentation.

    Args:
        vol (numpy.ndarray): The tomogram; its first axis is the slider axis.
        seg (numpy.ndarray): The segmentation displayed alongside the tomogram.
    """

    # Slider spans every index of the first axis and starts in the middle
    last_index = vol.shape[0] - 1
    slice_slider = IntSlider(min=0, max=last_index, step=1, value=int(last_index // 2))

    # Only the slice index is interactive; both volumes are held fixed
    interact(
        show_tomo_segmentation,
        tomo=fixed(vol),
        seg=fixed(seg),
        vol_slice=slice_slider,
    )
|
|
29
|
+
|
|
30
|
+
# Define the plotting function
|
|
31
|
+
def show_tomo_segmentation(tomo, seg, vol_slice):
    """
    Draw one slice as three panels: tomogram, segmentation, and overlay.

    Args:
        tomo (numpy.ndarray): The tomogram, indexed by slice on its first axis.
        seg (numpy.ndarray): The segmentation, indexed the same way.
        vol_slice (int): The slice index to show.
    """

    tomo_image = tomo[vol_slice]
    seg_image = seg[vol_slice]

    # Each panel is (title, [(image, imshow keyword arguments), ...])
    panels = [
        ('Tomogram', [(tomo_image, {'cmap': 'gray'})]),
        ('Painted Segmentation from Picks', [(seg_image, {'cmap': 'viridis'})]),
        ('Overlay', [
            (tomo_image, {'cmap': 'gray'}),
            (seg_image, {'cmap': 'viridis', 'alpha': 0.5}),  # 50% transparent overlay
        ]),
    ]

    plt.figure(figsize=(20, 10))

    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        for image, imshow_kwargs in layers:
            plt.imshow(image, **imshow_kwargs)
        plt.axis('off')

    plt.tight_layout()
    plt.show()
|
|
63
|
+
|
|
64
|
+
def show_labeled_tomo_segmentation(tomo, seg, seg_labels, unique_values, vol_slice):
    """
    Show a tomogram slice next to its segmentation with a labeled color bar.

    Args:
        tomo (numpy.ndarray): The tomogram, indexed by slice on its first axis.
        seg (numpy.ndarray): The segmentation, indexed the same way.
        seg_labels (dict): Maps a label value to the text shown on the color bar.
        unique_values: Container of label values to keep; entries of
            ``seg_labels`` whose key is not in it are dropped.
        vol_slice (int): The slice index to show.

    Raises:
        ValueError: If no key of ``seg_labels`` is present in ``unique_values``.
    """

    # Keep only the labels that are actually present
    present_labels = {k: v for k, v in seg_labels.items() if k in unique_values}

    # Checked before any figure is created. The earlier code failed here too
    # (via max() on an empty sequence), but with an opaque message and only
    # after opening an empty figure.
    if not present_labels:
        raise ValueError(
            "None of the keys in seg_labels are present in unique_values; "
            "there is nothing to label."
        )
    label_values = list(present_labels)

    plt.figure(figsize=(20, 10))

    # Discrete colormap with one color per present label
    # (other colormaps such as 'Set3' or 'tab20' work as well).
    # NOTE(review): no norm is passed to imshow, so colors follow matplotlib's
    # default scaling of the displayed slice - confirm this is intended when
    # label values are not contiguous.
    colors = plt.cm.tab20b(np.linspace(0, 1, len(label_values)))
    cmap = mcolors.ListedColormap(colors)

    # Tomogram plot
    plt.subplot(1, 2, 1)
    plt.title('Tomogram')
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    # Prediction segmentation plot
    plt.subplot(1, 2, 2)
    plt.title('Prediction Segmentation')
    im = plt.imshow(seg[vol_slice], cmap=cmap)
    plt.axis('off')

    # Color bar with one tick per present label, annotated with its name
    cbar = plt.colorbar(im, ticks=label_values)
    cbar.ax.set_yticklabels([present_labels[value] for value in label_values])

    plt.tight_layout()
    plt.show()
|
|
100
|
+
|
|
101
|
+
def interact_points(
    tomo, config, run_id, user_id='octopi',
    session_id = None, pt_size = 15,
    slice_proximity_threshold = 3
    ):
    """
    Launch a slider widget for browsing picked points over a tomogram.

    Args:
        tomo (numpy.ndarray): The tomogram; its first axis is the slider axis.
        config (str): Copick configuration passed to ``copick.from_file``.
        run_id (str): The ID of the run whose picks are shown.
        user_id (str): The user ID of the picks to show.
        session_id (str): The session ID of the picks to show.
        pt_size (int): The marker size of the plotted points.
        slice_proximity_threshold (int): Maximum distance between a point and
            the displayed slice for the point to be drawn.
    """

    # Load Copick Project and Run
    project = copick.from_file(config)
    selected_run = project.get_run(run_id)

    # (name, label, radius) for every pickable object flagged as a particle
    particle_objects = [
        (entry.name, entry.label, entry.radius)
        for entry in project.pickable_objects
        if entry.is_particle
    ]

    # Slider spans every index of the first axis and starts in the middle
    last_index = tomo.shape[0] - 1
    slice_slider = IntSlider(min=0, max=last_index, step=1, value=int(last_index // 2))

    # Everything except the slice index is held fixed
    interact(
        show_tomo_points,
        tomo=fixed(tomo),
        run=fixed(selected_run),
        objects=fixed(particle_objects),
        user_id=fixed(user_id),
        session_id=fixed(session_id),
        slice_proximity_threshold=fixed(slice_proximity_threshold),
        pt_size=fixed(pt_size),
        vol_slice=slice_slider,
    )
|
|
138
|
+
|
|
139
|
+
def show_tomo_points(
    tomo, run, objects, user_id,
    vol_slice, session_id = None,
    slice_proximity_threshold = 3, pt_size = 15
    ):
    """
    Show Coordinates on a Tomogram Slice.

    Args:
        tomo (numpy.ndarray): The tomogram, indexed by slice on its first axis.
        run (copick.Run): The Copick Run to read the points from.
        objects (list): 3-tuples whose first element is the object name;
            the other two elements are ignored here.
        user_id (str): The user ID of the picks to show.
        vol_slice (int): The slice index to show.
        session_id (str): The session ID of the picks to show.
        slice_proximity_threshold (int): Maximum absolute difference between a
            point's first coordinate and ``vol_slice`` for it to be drawn.
        pt_size (int): The marker size of the plotted points.
    """

    plt.figure(figsize=(20, 10))

    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    for name, _, _ in objects:
        try:
            coordinates = readers.coordinates(run, name=name, user_id=user_id, session_id=session_id)
            # Column 0 is compared with the slice index; columns 2 and 1 are
            # plotted as x and y (presumably (z, y, x) ordering - TODO confirm).
            near_slice = np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold
            close_points = coordinates[near_slice]
            plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=pt_size)
        except Exception:
            # Best effort: an object whose picks cannot be read or plotted is
            # skipped so the rest are still drawn. Catching Exception (not a
            # bare except) lets KeyboardInterrupt and SystemExit propagate.
            continue

    plt.show()
|
|
172
|
+
|
|
173
|
+
def _scatter_picks_near_slice(run, objects, vol_slice, user_id, session_id,
                              slice_proximity_threshold, pt_size):
    """Scatter, on the current axes, each object's picks that lie near ``vol_slice``."""
    for name, _, _ in objects:
        try:
            coordinates = readers.coordinates(run, name=name, user_id=user_id, session_id=session_id)
            # Column 0 is compared with the slice index; columns 2 and 1 are
            # plotted as x and y (presumably (z, y, x) ordering - TODO confirm).
            near_slice = np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold
            close_points = coordinates[near_slice]
            plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=pt_size)
        except Exception:
            # Best effort: skip objects whose picks cannot be read or plotted.
            # Catching Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            continue


def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
                        session_id1 = None, session_id2 = None, slice_proximity_threshold = 3,
                        pt_size = 15):
    """
    Show the picks of two users side by side on the same tomogram slice.

    Args:
        tomo (numpy.ndarray): The tomogram, indexed by slice on its first axis.
        run (copick.Run): The Copick Run to read the points from.
        objects (list): 3-tuples whose first element is the object name.
        vol_slice (int): The slice index to show.
        user_id1 (str): User ID of the picks shown in the left panel.
        user_id2 (str): User ID of the picks shown in the right panel.
        session_id1 (str): Session ID of the left-panel picks.
        session_id2 (str): Session ID of the right-panel picks.
        slice_proximity_threshold (int): Maximum absolute difference between a
            point's first coordinate and ``vol_slice`` for it to be drawn.
        pt_size (int): The marker size of the plotted points (default 15, the
            value that used to be hard-coded).
    """

    plt.figure(figsize=(20, 10))

    panels = [(user_id1, session_id1), (user_id2, session_id2)]
    for position, (user_id, session_id) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(tomo[vol_slice], cmap='gray')
        plt.title(f'{user_id} Picks')

        _scatter_picks_near_slice(
            run, objects, vol_slice, user_id, session_id,
            slice_proximity_threshold, pt_size,
        )

        # Hide the axes on BOTH panels; previously only the second panel was
        # affected because the call sat after the final subplot.
        plt.axis('off')

    plt.show()
|
|
203
|
+
|
|
204
|
+
def plot_training_results(
    results,
    class_names: Optional[List[str]] = None,
    save_plot: Optional[str] = None,
    fig = None, axs = None):
    """
    Plot Training Results including Loss, Recall, Precision, and F1 Score.

    Args:
        results (dict): Training metrics. Must contain ``'loss'``,
            ``'val_loss'`` and ``'avg_recall'``, plus ``'recall_class<i>'``,
            ``'precision_class<i>'`` and ``'f1_class<i>'`` for ``i = 1..N``.
            Every value is an iterable of ``(epoch, value)`` pairs.
        class_names (list, optional): One name per class for the legend. If
            None or of the wrong length, ``"Class 1"``, ``"Class 2"``, ... are used.
        save_plot (str, optional): Path to save the plot image. When None the
            figure is not written to disk.
        fig (matplotlib.figure.Figure, optional): Existing figure to plot on.
            A new 2x2 figure is created when None.
        axs (optional): Existing 2x2 axes grid belonging to ``fig``; it must
            support ``axs[row, col]`` and ``flatten()``. Used (and cleared)
            only when ``fig`` is given.

    Returns:
        tuple: ``(fig, axs)`` - the figure and axes grid that were drawn on.
    """

    # Create a 2x2 subplot layout
    if fig is None:
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    else:
        # Clear previous plots
        for ax in axs.flatten():
            ax.clear()

    fig.suptitle("Metrics Over Epochs", fontsize=16)

    tick_style = dict(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Unpack the data for loss
    epochs_loss = [epoch for epoch, _ in results['loss']]
    loss = [value for _, value in results['loss']]
    val_epochs_loss = [epoch for epoch, _ in results['val_loss']]
    val_loss = [value for _, value in results['val_loss']]

    # Plot Training and Validation Loss in the top-left
    loss_ax = axs[0, 0]
    loss_ax.plot(epochs_loss, loss, label="Training Loss")
    loss_ax.plot(val_epochs_loss, val_loss, label='Validation Loss')
    loss_ax.set_xlabel("Epochs")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training Loss")
    loss_ax.legend()
    loss_ax.tick_params(**tick_style)

    # Per-class metrics share the epoch axis of 'avg_recall'
    epochs_metrics = [epoch for epoch, _ in results['avg_recall']]

    # Determine the number of classes and names
    num_classes = len([key for key in results.keys() if key.startswith('recall_class')])

    if class_names is None or len(class_names) != num_classes:
        class_names = [f"Class {i+1}" for i in range(num_classes)]

    def _plot_per_class(ax, key_prefix, ylabel, title, show_legend):
        """Draw one curve per class for the metric stored under ``<key_prefix>_class<i>``."""
        for class_idx in range(num_classes):
            series = [value for _, value in results[f'{key_prefix}_class{class_idx+1}']]
            ax.plot(epochs_metrics, series, label=f"{class_names[class_idx]}")
        ax.set_xlabel("Epochs")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if show_legend:
            ax.legend()
        ax.tick_params(**tick_style)

    # Recall (top-right), Precision (bottom-left), F1 (bottom-right).
    # Only the precision panel carries the class legend, as before.
    _plot_per_class(axs[0, 1], 'recall', "Recall", "Recall per Class", show_legend=False)
    _plot_per_class(axs[1, 0], 'precision', "Precision", "Precision per Class", show_legend=True)
    _plot_per_class(axs[1, 1], 'f1', "F1 Score", "F1 Score per Class", show_legend=False)

    # Adjust the layout of THIS figure (pyplot's tight_layout would act on the
    # current figure, which need not be ``fig``), leaving room for the title
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # save_plot defaults to None; only write the image when a path was given
    if save_plot is not None:
        fig.savefig(save_plot)
    fig.canvas.draw()

    return fig, axs
|
octopi/workflows.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
from octopi.extract.localize import process_localization
|
|
2
|
+
import octopi.processing.evaluate as octopi_evaluate
|
|
3
|
+
from monai.metrics import ConfusionMatrixMetric
|
|
4
|
+
from octopi.models import common as builder
|
|
5
|
+
from octopi.pytorch import segmentation
|
|
6
|
+
from octopi.pytorch import trainer
|
|
7
|
+
from octopi.utils import io
|
|
8
|
+
import multiprocess as mp
|
|
9
|
+
import copick, torch, os
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
def train(data_generator, loss_function, num_crops = 16,
          model_config = None, model_weights = None, lr0 = 1e-3,
          model_save_path = 'results', best_metric = 'fBeta2',
          num_epochs = 1000, use_ema = True, val_interval = 10,
          sw_bs = 4, overlap = 0.5, ):
    """
    Train a Segmentation Model and Save its Parameters and Results

    Args:
        data_generator: Data source handed to the trainer. Its ``Nclasses``
            attribute is read when the default model configuration is used.
        loss_function: Loss object handed to ``trainer.ModelTrainer``.
        num_crops (int): Forwarded to the trainer as ``my_num_samples``.
        model_config (dict): The model configuration. A dict containing a
            ``'model'`` key is treated as a full config and that sub-dict is
            used. When None, a default ``'Unet'`` configuration is built.
        model_weights (str): Optional path to weights loaded into the model
            before training.
        lr0 (float): Learning rate given to the AdamW optimizer.
        model_save_path (str): Directory passed to the trainer; also where
            ``model_config.yaml`` and ``results.csv`` are written.
        best_metric (str): Forwarded to the trainer as ``best_metric``.
        num_epochs (int): Forwarded to the trainer as ``max_epochs``.
        use_ema (bool): Forwarded to ``trainer.ModelTrainer``.
        val_interval (int): The number of epoch intervals for validation during training.
        sw_bs (int): Stored on the trainer as ``sw_bs``
            (sliding window batch size for validation).
        overlap (float): Stored on the trainer as ``overlap``
            (sliding window overlap during validation).
    """

    # Resolve the model configuration: default, nested under 'model', or as given
    if model_config is None:
        model_config = dict(
            architecture='Unet',
            num_classes=data_generator.Nclasses,
            dim_in=80,
            strides=[2, 2, 1],
            channels=[48, 64, 80, 80],
            dropout=0.0,
            num_res_units=1,
        )
        print('No Model Configuration Provided, Using Default Configuration')
        print(model_config)
    elif isinstance(model_config, dict) and 'model' in model_config:
        model_config = model_config['model']

    # Per-class recall / precision / F1 from MONAI, background excluded
    confusion_metrics = ConfusionMatrixMetric(
        include_background=False,
        metric_name=["recall", 'precision', 'f1 score'],
        reduction="none",
    )

    # Build the network for the requested architecture
    arch_builder = builder.get_model(model_config['architecture'])
    network = arch_builder.build_model(model_config)

    # Optionally restore weights, then move the network to GPU when available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if model_weights:
        print(f'Loading Model Weights from: {model_weights}\n')
        restored_state = torch.load(model_weights, map_location=device, weights_only=True)
        network.load_state_dict(restored_state)
    network.to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(network.parameters(), lr=lr0, weight_decay=1e-4)

    # Trainer, with the sliding-window settings attached as attributes
    model_trainer = trainer.ModelTrainer(
        network, device, loss_function, confusion_metrics,
        optimizer, use_ema = use_ema
    )
    model_trainer.sw_bs = sw_bs
    model_trainer.overlap = overlap

    # Train the Model
    print(f'🔃 Starting Training...\nSaving Training Results to: {model_save_path}/\n')
    training_results = model_trainer.train(
        data_generator,
        model_save_path,
        max_epochs=num_epochs,
        crop_size=model_config['dim_in'],
        my_num_samples=num_crops,
        val_interval=val_interval,
        best_metric=best_metric,
        verbose=True,
    )
    print('✅ Training Complete!')

    # Save parameters (YAML) and per-epoch results (CSV) next to the model
    print(f'💾 Saving Training Parameters and Results to: {model_save_path}/\n')
    yaml_path = os.path.join(model_save_path, "model_config.yaml")
    io.save_parameters_to_yaml(arch_builder, model_trainer, data_generator, yaml_path)

    csv_path = os.path.join(model_save_path, "results.csv")
    io.save_results_to_csv(training_results, csv_path)
|
|
99
|
+
|
|
100
|
+
def segment(config, tomo_algorithm, voxel_size, model_weights, model_config,
            seg_info = ['predict', 'octopi', '1'], use_tta = False, run_ids = None):
    """
    Segment a Dataset using a Trained Model or Ensemble of Models

    Args:
        config (str): Path to the Copick Config File
        tomo_algorithm (str): The tomogram algorithm name forwarded to the predictor
        voxel_size (float): The voxel size, forwarded as ``voxel_spacing``
        model_weights (str, list): The path to the model weights or a list of paths
        model_config (str, list): The model configuration or a list of model configurations
        seg_info (list): ``[name, user_id, session_id]`` under which the
            segmentation is written. The default list is only read, never modified.
        use_tta (bool): Whether to use test time augmentation
        run_ids (list): The list of run IDs to segment, forwarded as ``runIDs``
    """

    # Initialize the Predictor
    predictor = segmentation.Predictor(
        config, model_config, model_weights,
        apply_tta = use_tta
    )

    # Gather the batch-prediction settings, then run
    batch_settings = dict(
        runIDs=run_ids,
        num_tomos_per_batch=15,
        tomo_algorithm=tomo_algorithm,
        voxel_spacing=voxel_size,
        segmentation_name=seg_info[0],
        segmentation_user_id=seg_info[1],
        segmentation_session_id=seg_info[2],
    )
    predictor.batch_predict(**batch_settings)

    print('✅ Segmentation Complete!')
|
|
134
|
+
|
|
135
|
+
def localize(config, voxel_size, seg_info, pick_user_id, pick_session_id, n_procs = 16,
             method = 'watershed', filter_size = 10, radius_min_scale = 0.4, radius_max_scale = 1.0,
             run_ids = None, pick_objects = None):
    """
    Extract 3D Coordinates from the Segmentation Maps

    Args:
        config (str): Path to the Copick Config File
        voxel_size (float): The voxel size of the data
        seg_info (list): ``[name, user_id, session_id]`` of the segmentation
        pick_user_id (str): The user ID of the pick
        pick_session_id (str): The session ID of the pick
        n_procs (int): The number of processes to use for parallelization
        method (str): The method to use for localization
        filter_size (int): The filter size to use for localization
        radius_min_scale (float): The minimum radius scale to use for localization
        radius_max_scale (float): The maximum radius scale to use for localization
        run_ids (list): The list of run IDs to use for localization.
            When None, every run in the project is used.
        pick_objects (list): Optional object names; when given, only objects
            whose name is in it are localized.

    Raises:
        ValueError: If a particle object's label or radius cannot be converted
            to a number, or if ``pick_objects`` filters out every object.
    """

    # Load the Copick Config
    root = copick.from_file(config)

    # Build mutable [name, label, radius] rows for the particle objects.
    # Validation wraps the conversion: previously the check ran only after
    # int()/float() had succeeded, so it could never fire and a missing radius
    # surfaced as a bare TypeError.
    objects = []
    for obj in root.pickable_objects:
        if not obj.is_particle:
            continue
        try:
            objects.append([obj.name, int(obj.label), float(obj.radius)])
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"Invalid object format: {(obj.name, obj.label, obj.radius)}. "
                "Expected a tuple with (name, label, radius)."
            ) from err

    # Load the Model Output Configuration
    seg_config = io.get_config(config, seg_info[0], 'segment', seg_info[1], seg_info[2])

    # Sync labels from the model config and drop objects the model has no label for
    label_map = seg_config.get('labels', {})
    objects = [
        [name, int(label_map[name]), radius]
        for name, _, radius in objects
        if name in label_map
    ]

    # Filter objects based on the provided list
    if pick_objects is not None:
        available_names = [row[0] for row in objects]  # kept for the error message
        objects = [row for row in objects if row[0] in pick_objects]
        if len(objects) == 0:
            raise ValueError(f"No valid objects found for localization after filtering. Mismatched names: {pick_objects} and {available_names}")

    # Get all RunIDs
    if run_ids is None:
        run_ids = [run.name for run in root.runs]
    n_run_ids = len(run_ids)

    # Exit if No Runs are Available
    if n_run_ids == 0:
        print(f"No runs available for localization with the specified voxel size - {voxel_size}.\nExiting...")
        return

    # Run Localization - Main Parallelization Loop
    print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
    with mp.Pool(processes=n_procs) as pool:
        with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
            # A lambda is kept on purpose: the pool comes from `multiprocess`,
            # whose serializer is relied on to ship this closure to the workers.
            # NOTE(review): session ID is passed before user ID - confirm this
            # matches the parameter order of process_localization.
            worker_func = lambda run_id: process_localization(
                root.get_run(run_id),
                objects,
                seg_info,
                method,
                voxel_size,
                filter_size,
                radius_min_scale,
                radius_max_scale,
                pick_session_id,
                pick_user_id
            )

            # Results are discarded; iteration only drives the progress bar
            for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
                pbar.update(1)

    print('✅ Localization Complete!')
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def evaluate(config,
             gt_user_id, gt_session_id,
             pred_user_id, pred_session_id,
             run_ids = None, distance_threshold = 0.5, save_path = None):
    """
    Evaluate the Localization on a Dataset

    Args:
        config (str): Path to the Copick Config File
        gt_user_id (str): The user ID of the ground truth
        gt_session_id (str): The session ID of the ground truth
        pred_user_id (str): The user ID of the predicted coordinates
        pred_session_id (str): The session ID of the predicted coordinates
        run_ids (list): The list of run IDs to use for evaluation
        distance_threshold (float): Forwarded to the evaluator as
            ``distance_threshold_scale``
        save_path (str): The path to save the evaluation results
    """

    # Echo the query so it appears in the job log
    for line in (
        'Running Evaluation on the Following Query:',
        f'Distance Threshold: {distance_threshold}',
        f'GT User ID: {gt_user_id}, GT Session ID: {gt_session_id}',
        f'Pred User ID: {pred_user_id}, Pred Session ID: {pred_session_id}',
        f'Run IDs: {run_ids}',
    ):
        print(line)

    # Load the Copick Config. The returned project is not used below; the call
    # is retained so behavior (including any error raised while loading the
    # config) is unchanged.
    copick.from_file(config)

    # Object names are not filtered for now, so None is passed through.
    # The local is not called `eval`, to avoid shadowing the builtin.
    localization_evaluator = octopi_evaluate.evaluator(
        config,
        gt_user_id,
        gt_session_id,
        pred_user_id,
        pred_session_id,
        object_names=None,
    )

    localization_evaluator.run(
        distance_threshold_scale=distance_threshold,
        runIDs=run_ids,
        save_path=save_path,
    )
|