octopi-1.0-py3-none-any.whl → octopi-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- octopi/__init__.py +1 -0
- octopi/datasets/cached_datset.py +1 -1
- octopi/datasets/generators.py +1 -1
- octopi/datasets/io.py +200 -0
- octopi/datasets/multi_config_generator.py +1 -1
- octopi/entry_points/common.py +9 -9
- octopi/entry_points/create_slurm_submission.py +16 -8
- octopi/entry_points/run_create_targets.py +6 -6
- octopi/entry_points/run_evaluate.py +4 -3
- octopi/entry_points/run_extract_mb_picks.py +22 -45
- octopi/entry_points/run_localize.py +37 -54
- octopi/entry_points/run_optuna.py +7 -7
- octopi/entry_points/run_segment_predict.py +4 -4
- octopi/entry_points/run_train.py +7 -8
- octopi/extract/localize.py +19 -12
- octopi/extract/membranebound_extract.py +11 -10
- octopi/extract/midpoint_extract.py +3 -3
- octopi/main.py +1 -1
- octopi/models/common.py +1 -1
- octopi/processing/create_targets_from_picks.py +11 -5
- octopi/processing/downsample.py +6 -10
- octopi/processing/evaluate.py +24 -11
- octopi/processing/importers.py +4 -4
- octopi/pytorch/hyper_search.py +2 -3
- octopi/pytorch/model_search_submitter.py +15 -15
- octopi/pytorch/segmentation.py +147 -192
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +9 -3
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +128 -0
- octopi/{utils.py → utils/parsers.py} +10 -84
- octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
- octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
- octopi/workflows.py +236 -0
- octopi-1.2.0.dist-info/METADATA +120 -0
- octopi-1.2.0.dist-info/RECORD +62 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
- octopi-1.2.0.dist-info/entry_points.txt +3 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
- octopi/io.py +0 -457
- octopi/processing/my_metrics.py +0 -26
- octopi/processing/writers.py +0 -102
- octopi-1.0.dist-info/METADATA +0 -209
- octopi-1.0.dist-info/RECORD +0 -59
- octopi-1.0.dist-info/entry_points.txt +0 -4
- /octopi/{losses.py → utils/losses.py} +0 -0
- /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/entry_points/run_optuna.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.pytorch.model_search_submitter import ModelSearchSubmit
 from octopi.entry_points import common
+from octopi.utils import parsers, io
 import argparse, os, pprint
-from octopi import utils

 def optuna_parser(parser_description, add_slurm: bool = False):
 """
@@ -20,22 +20,22 @@ def optuna_parser(parser_description, add_slurm: bool = False):
 # Input Arguments
 input_group = parser.add_argument_group("Input Arguments")
 common.add_config(input_group, single_config=False)
-input_group.add_argument("--target-info", type=
+input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
 help="Target information, e.g., 'name' or 'name,user_id,session_id'")
 input_group.add_argument("--tomo-alg", default='wbp',
 help="Tomogram algorithm used for training")
 input_group.add_argument("--mlflow-experiment-name", type=str, default="model-search", required=False,
 help="Name of the MLflow experiment (default: 'model-search').")
-input_group.add_argument("--trainRunIDs", type=
+input_group.add_argument("--trainRunIDs", type=parsers.parse_list, default=None, required=False,
 help="List of training run IDs, e.g., run1,run2 or [run1,run2].")
-input_group.add_argument("--validateRunIDs", type=
+input_group.add_argument("--validateRunIDs", type=parsers.parse_list, default=None, required=False,
 help="List of validation run IDs, e.g., run3,run4 or [run3,run4].")
 input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
 "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")

 model_group = parser.add_argument_group("Model Arguments")
 model_group.add_argument("--model-type", type=str, default='Unet', required=False,
-choices=['Unet', 'AttentionUnet'],
+choices=['Unet', 'AttentionUnet', 'MedNeXt', 'SegResNet'],
 help="Model type to use for training")
 model_group.add_argument("--Nclass", type=int, default=3, required=False, help="Number of prediction classes in the model")

@@ -61,7 +61,7 @@ def cli():
 args = optuna_parser(description)

 # Parse the CoPick configuration paths
-if len(args.config) > 1: copick_configs =
+if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
 else: copick_configs = args.config[0]

 # Create the model exploration directory
@@ -133,7 +133,7 @@ def save_parameters(args: argparse.Namespace,
 pprint.pprint(params); print()

 # Save to YAML file
-
+io.save_parameters_yaml(params, output_path)

 if __name__ == "__main__":
 cli()
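Note: the recurring type=parsers.parse_target / parsers.parse_list / parsers.parse_copick_configs changes in this release move small argparse helpers from the old flat octopi/utils.py into the new octopi/utils/parsers.py. The sketch below only illustrates how such comma-separated type= callbacks typically behave, inferred from the help strings shown above ("run1,run2 or [run1,run2]", "'name' or 'name,user_id,session_id'"); the shipped implementations in octopi.utils.parsers may differ.

# Hypothetical sketch of comma-separated argparse "type=" callbacks;
# not the actual octopi.utils.parsers code.
def parse_list(value: str) -> list:
    """Turn 'run1,run2' or '[run1,run2]' into ['run1', 'run2']."""
    value = value.strip().strip("[]")
    return [item.strip() for item in value.split(",") if item.strip()]

def parse_target(value: str) -> tuple:
    """Turn 'name' or 'name,user_id,session_id' into (name, user_id, session_id)."""
    parts = [p.strip() for p in value.split(",")]
    name = parts[0]
    user_id = parts[1] if len(parts) > 1 else None
    session_id = parts[2] if len(parts) > 2 else None
    return name, user_id, session_id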
octopi/entry_points/run_segment_predict.py
CHANGED
@@ -1,8 +1,8 @@
+import torch, argparse, json, pprint, yaml, os
 from octopi.pytorch import segmentation
 from octopi.entry_points import common
-import torch, argparse, json, pprint, yaml, os
-from octopi import utils
 from typing import List, Tuple
+from octopi.utils import io

 def inference(
 copick_config_path: str,
@@ -136,7 +136,7 @@ def save_parameters(args: argparse.Namespace,
 output_path: str):

 # Load the model config
-model_config =
+model_config = io.load_yaml(args.model_config)

 # Create parameters dictionary
 params = {
@@ -160,7 +160,7 @@ def save_parameters(args: argparse.Namespace,
 pprint.pprint(params); print()

 # Save to YAML file
-
+io.save_parameters_yaml(params, output_path)

 if __name__ == "__main__":
 cli()
octopi/entry_points/run_train.py
CHANGED
@@ -2,12 +2,11 @@ from octopi.datasets import generators, multi_config_generator
 from monai.losses import DiceLoss, FocalLoss, TverskyLoss
 from octopi.models import common as builder
 from monai.metrics import ConfusionMatrixMetric
+from octopi.utils import parsers, io
 from octopi.entry_points import common
 from octopi.pytorch import trainer
-from octopi import io, utils
 import torch, os, argparse
 from typing import List, Optional, Tuple
-import pprint

 def train_model(
 copick_config_path: str,
@@ -56,7 +55,7 @@ def train_model(


 # Get the data splits
-ratios =
+ratios = parsers.parse_data_split(data_split)
 data_generator.get_data_splits(
 trainRunIDs = trainRunIDs,
 validateRunIDs = validateRunIDs,
@@ -114,11 +113,11 @@ def train_model_parser(parser_description, add_slurm: bool = False):
 # Input Arguments
 input_group = parser.add_argument_group("Input Arguments")
 common.add_config(input_group, single_config=False)
-input_group.add_argument("--target-info", type=
+input_group.add_argument("--target-info", type=parsers.parse_target, default="targets,octopi,1",
 help="Target information, e.g., 'name' or 'name,user_id,session_id'. Default is 'targets,octopi,1'.")
 input_group.add_argument("--tomo-alg", default='wbp', help="Tomogram algorithm used for training")
-input_group.add_argument("--trainRunIDs", type=
-input_group.add_argument("--validateRunIDs", type=
+input_group.add_argument("--trainRunIDs", type=parsers.parse_list, help="List of training run IDs, e.g., run1,run2,run3")
+input_group.add_argument("--validateRunIDs", type=parsers.parse_list, help="List of validation run IDs, e.g., run4,run5,run6")
 input_group.add_argument('--data-split', type=str, default='0.8', help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) "
 "or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")

@@ -153,11 +152,11 @@ def cli():
 args = train_model_parser(parser_description)

 # Parse the CoPick configuration paths
-if len(args.config) > 1: copick_configs =
+if len(args.config) > 1: copick_configs = parsers.parse_copick_configs(args.config)
 else: copick_configs = args.config[0]

 if args.model_config:
-model_config =
+model_config = io.load_yaml(args.model_config)
 else:
 model_config = get_model_config(args.channels, args.strides, args.res_units, args.Nclass, args.dim_in)
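Note: both run_segment_predict.py and run_train.py now call io.load_yaml and io.save_parameters_yaml from the new octopi/utils/io.py (+128 lines) instead of the removed octopi/io.py. The helpers below are a minimal sketch of what such YAML round-tripping plausibly looks like, included only for orientation; the actual octopi.utils.io functions may add validation or extra formatting.

# Hypothetical stand-ins for the two helpers referenced above;
# the shipped octopi.utils.io implementations may differ.
import yaml

def load_yaml(path: str) -> dict:
    """Read a YAML file (e.g. a saved model config) into a dict."""
    with open(path, "r") as f:
        return yaml.safe_load(f)

def save_parameters_yaml(params: dict, output_path: str) -> None:
    """Write a parameters dict to YAML, preserving insertion order."""
    with open(output_path, "w") as f:
        yaml.dump(params, f, default_flow_style=False, sort_keys=False)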
octopi/extract/localize.py
CHANGED
@@ -3,15 +3,15 @@ from scipy.cluster.hierarchy import fcluster, linkage
 from skimage.segmentation import watershed
 from typing import List, Optional, Tuple
 from skimage.measure import regionprops
+from copick_utils.io import readers
 from scipy.spatial import distance
 from dataclasses import dataclass
-from octopi import io
 import scipy.ndimage as ndi
 from tqdm import tqdm
 import numpy as np
-import
+import gc

-def
+def process_localization(run,
 objects,
 seg_info: Tuple[str, str, str],
 method: str = 'com',
@@ -27,12 +27,12 @@ def processs_localization(run,
 raise ValueError(f"Invalid method '{method}'. Expected 'watershed' or 'com'.")

 # Get Segmentation
-seg =
-
-
-
-
-
+seg = readers.segmentation(
+run, voxel_size,
+seg_info[0],
+user_id=seg_info[1],
+session_id=seg_info[2],
+raise_error=False)

 # Preprocess Segmentation
 # seg = preprocess_segmentation(seg, voxel_size, objects)
@@ -99,15 +99,15 @@ def extract_particle_centroids_via_watershed(
 max_particle_size (int): Maximum size threshold for particles.
 """

-if maxima_filter_size is None or maxima_filter_size
-
+if maxima_filter_size is None or maxima_filter_size <= 0:
+raise ValueError('Enter a Non-Zero Filter Size!')

 # Calculate minimum and maximum particle volumes based on the given radii
 min_particle_size = (4 / 3) * np.pi * (min_particle_radius ** 3)
 max_particle_size = (4 / 3) * np.pi * (max_particle_radius ** 3)

 # Create a binary mask for the specific segmentation label
-binary_mask = (segmentation == segmentation_idx).astype(
+binary_mask = (segmentation == segmentation_idx).astype(np.uint8)

 # Skip if the segmentation label is not present
 if np.sum(binary_mask) == 0:
@@ -117,6 +117,7 @@ def extract_particle_centroids_via_watershed(
 # Structuring element for erosion and dilation
 struct_elem = ball(1)
 eroded = binary_erosion(binary_mask, struct_elem)
+
 dilated = binary_dilation(eroded, struct_elem)

 # Distance transform and local maxima detection
@@ -125,7 +126,13 @@ def extract_particle_centroids_via_watershed(

 # Watershed segmentation
 markers, _ = ndi.label(local_max)
+del local_max
+gc.collect()
+
 watershed_labels = watershed(-distance, markers, mask=dilated)
+distance, markers, dilated = None, None, None
+del distance, markers, dilated
+gc.collect()

 # Extract region properties and filter based on particle size
 all_centroids = []
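Note: the extra del / gc.collect() calls added around the watershed step above release large intermediate arrays (local_max, distance, markers, dilated) as soon as they are no longer needed, lowering peak memory on big tomograms. A self-contained illustration of the same pattern (the array and function names below are made up for the example, not octopi code):

import gc
import numpy as np

def reduce_then_free(volume: np.ndarray) -> float:
    # Build a large temporary array...
    squared = volume.astype(np.float64) ** 2
    result = float(squared.sum())
    # ...then drop the reference and ask the collector to reclaim it
    # before the next memory-hungry step runs.
    del squared
    gc.collect()
    return result

print(reduce_then_free(np.ones((64, 64, 64), dtype=np.float32)))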
octopi/extract/membranebound_extract.py
CHANGED
@@ -1,5 +1,5 @@
 from scipy.spatial.transform import Rotation as R
-from
+from copick_utils.io import readers
 import scipy.ndimage as ndi
 from typing import Tuple
 import numpy as np
@@ -36,7 +36,7 @@ def process_membrane_bound_extract(run,
 new_session_id = str(int(save_session_id) + 1) # Convert to string after increment

 # Need Better Error Handing for Missing Picks
-coordinates =
+coordinates = readers.coordinates(
 run,
 picks_info[0], picks_info[1], picks_info[2],
 voxel_size,
@@ -54,12 +54,13 @@ def process_membrane_bound_extract(run,
 if membrane_info is None:
 # Flag to distinguish between organelle and membrane segmentation
 membranes_provided = False
-seg =
-
-
-
-
-
+seg = readers.segmentation(
+run,
+voxel_size,
+organelle_info[0],
+user_id=organelle_info[1],
+session_id=organelle_info[2],
+raise_error=False)
 # If No Segmentation is Found, Return
 if seg is None: return
 elif nPoints == 0 or np.unique(seg).max() == 0:
@@ -68,7 +69,7 @@ def process_membrane_bound_extract(run,
 else:
 # Read both Organelle and Membrane Segmentations
 membranes_provided = True
-seg =
+seg = readers.segmentation(
 run,
 voxel_size,
 membrane_info[0],
@@ -76,7 +77,7 @@ def process_membrane_bound_extract(run,
 session_id=membrane_info[2],
 raise_error=False)

-organelle_seg =
+organelle_seg = readers.segmentation(
 run,
 voxel_size,
 organelle_info[0],
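Note: across localize.py and the two extract modules, reads of copick data now go through copick_utils.io.readers rather than the removed octopi/io.py. Mirroring the call patterns visible in the hunks above, typical usage looks like the sketch below; run is a copick run object and the object/user/session names are placeholders, so treat copick_utils itself as the authoritative reference for these signatures.

from copick_utils.io import readers

def load_inputs(run, voxel_size: float = 10.0):
    """Mirror of the reader calls shown in the diffs above."""
    # Segmentation lookup by name / user_id / session_id;
    # raise_error=False returns None instead of raising when missing.
    seg = readers.segmentation(
        run, voxel_size, "membrane",
        user_id="octopi", session_id="1",
        raise_error=False)
    # Picked coordinates for one object at the same voxel size.
    coordinates = readers.coordinates(
        run, "ribosome", "octopi", "1",
        voxel_size)
    return seg, coordinates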
octopi/extract/midpoint_extract.py
CHANGED
@@ -1,6 +1,6 @@
 from octopi.extract import membranebound_extract as extract
 from scipy.spatial.transform import Rotation as R
-from
+from copick_utils.io import readers
 from scipy.spatial import cKDTree
 from typing import Tuple
 import numpy as np
@@ -28,7 +28,7 @@ def process_midpoint_extract(
 """

 # Pull Picks that Are used for Midpoint Extraction
-coordinates =
+coordinates = readers.coordinates(
 run,
 picks_info[0], picks_info[1], picks_info[2],
 voxel_size
@@ -40,7 +40,7 @@ def process_midpoint_extract(
 save_picks_info[2] = save_session_id

 # Get Organelle Segmentation
-seg =
+seg = readers.segmentation(
 run,
 voxel_size,
 organelle_info[0],
octopi/main.py
CHANGED
@@ -33,7 +33,7 @@ def cli_main():
 "create-targets": (create_targets, "Generate segmentation targets from coordinates."),
 "train": (train_model, "Train a single U-Net model."),
 "model-explore": (model_explore, "Explore model architectures with Optuna / Bayesian Optimization."),
-"
+"segment": (inference, "Perform segmentation inference on tomograms."),
 "localize": (localize, "Perform localization of particles in tomograms."),
 "extract-mb-picks": (extract_mb_picks, "Extract MB Picks from tomograms."),
 "evaluate": (evaluate, "Evaluate the performance of a model."),
octopi/processing/create_targets_from_picks.py
CHANGED
@@ -1,6 +1,5 @@
 from octopi.processing.segmentation_from_picks import from_picks
-
-from octopi import io
+from copick_utils.io import readers, writers
 from typing import List
 from tqdm import tqdm
 import numpy as np
@@ -42,7 +41,11 @@ def generate_targets(

 # If runIDs are not provided, load all runs
 if run_ids is None:
-run_ids = [run.name for run in root.runs]
+run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is not None]
+skipped_run_ids = [run.name for run in root.runs if run.get_voxel_spacing(voxel_size) is None]
+
+if skipped_run_ids:
+print(f"Warning: skipping runs with no voxel spacing {voxel_size}: {skipped_run_ids}")

 # Iterate Over All Runs
 for runID in tqdm(run_ids):
@@ -52,7 +55,7 @@ def generate_targets(
 run = root.get_run(runID)

 # Get Tomogram
-tomo =
+tomo = readers.tomogram(run, voxel_size, tomo_algorithm)

 # Initialize Target Volume
 target = np.zeros(tomo.shape, dtype=np.uint8)
@@ -87,6 +90,9 @@ def generate_targets(
 session_id=train_targets[target_name]["session_id"],
 )

+# Filter out empty picks
+query = [pick for pick in query if pick.points is not None]
+
 # Add Picks to Target
 for pick in query:
 numPicks += len(pick.points)
@@ -100,7 +106,7 @@ def generate_targets(
 # Write Segmentation for non-empty targets
 if target.max() > 0 and numPicks > 0:
 tqdm.write(f'Annotating {numPicks} picks in {runID}...')
-
+writers.segmentation(run, target, target_user_name,
 name = target_segmentation_name, session_id= target_session_id,
 voxel_size = voxel_size)
 print('Creation of targets complete!')
octopi/processing/downsample.py
CHANGED
@@ -102,11 +102,6 @@ class FourierRescale:
 """
 in_depth, in_height, in_width = volume.shape[-3:]

-# Check if dimensions are odd
-d_is_odd = in_depth % 2
-h_is_odd = in_height % 2
-w_is_odd = in_width % 2
-
 # Calculate new dimensions
 extent_depth = in_depth * self.input_voxel_size[0]
 extent_height = in_height * self.input_voxel_size[1]
@@ -121,9 +116,10 @@ class FourierRescale:
 new_height = new_height - (new_height % 2)
 new_width = new_width - (new_width % 2)

-# Calculate starting points
-
-
-
+# Calculate starting points - properly centered around DC component
+# No odd/even correction needed - just center the crop
+start_d = (in_depth - new_depth) // 2
+start_h = (in_height - new_height) // 2
+start_w = (in_width - new_width) // 2

-return start_d, start_h, start_w, new_depth, new_height, new_width
+return start_d, start_h, start_w, new_depth, new_height, new_width
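Note: the FourierRescale change above drops the old odd/even bookkeeping and simply centres the Fourier-space crop on the DC component, with start = (in_size - new_size) // 2 on each axis after the new sizes are forced even. A quick check of that arithmetic (values chosen only for illustration):

# Illustrative check of the centred-crop arithmetic used above.
def centred_crop_start(in_size: int, new_size: int) -> int:
    return (in_size - new_size) // 2

# Rescaling a 100-sample axis down to 50 samples keeps indices 25..74,
# a window centred on the DC index (50 after fftshift).
start = centred_crop_start(100, 50)
print(start, start + 50)   # -> 25 75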
octopi/processing/evaluate.py
CHANGED
@@ -1,7 +1,7 @@
-from
+from copick_utils.io import readers
 from scipy.spatial import distance
+import copick, json, os, yaml
 from typing import List
-import copick, json, os
 import numpy as np

 class evaluator:
@@ -95,12 +95,12 @@ class evaluator:
 for name, radius in self.objects:

 # Get Ground Truth and Predicted Coordinates
-gt_coordinates =
+gt_coordinates = readers.coordinates(
 run, name,
 self.ground_truth_user_id, self.ground_truth_session_id,
 self.voxel_size, raise_error=False
 )
-pred_coordinates =
+pred_coordinates = readers.coordinates(
 run, name,
 self.prediction_user_id, self.predict_session_id,
 self.voxel_size, raise_error=False
@@ -202,14 +202,27 @@ class evaluator:
 }

 os.makedirs(save_path, exist_ok=True)
-summary_metrics = { "input": self.input_params,
-
-
-
-
+summary_metrics = { "input": self.input_params,
+"final_fbeta_score": final_fbeta,
+"aggregated_particle_scores": { # Optionally add per-particle details
+name: {
+"tp": counts['total_tp'],
+"fp": counts['total_fp'],
+"fn": counts['total_fn'],
+"weight": self.weights.get(name, 1)
+} for name, counts in aggregated_counts.items()
+},
+"summary_metrics": final_summary_metrics,
+"parameters": self.parameters, }
+
+# Save average metrics to YAML file
+with open(os.path.join(save_path, 'average_metrics.yaml'), 'w') as f:
+yaml.dump(summary_metrics, f, indent=4, default_flow_style=False, sort_keys=False)
+print(f'\nAverage Metrics saved to {os.path.join(save_path, "average_metrics.yaml")}')

-detailed_metrics = { "input": self.input_params,
-
+detailed_metrics = { "input": self.input_params,
+"metrics": metrics,
+"parameters": self.parameters, }
 with open(os.path.join(save_path, 'metrics.json'), 'w') as f:
 json.dump(detailed_metrics, f, indent=4)
 print(f'Metrics saved to {os.path.join(save_path, "metrics.json")}')
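Note: the evaluator now writes an average_metrics.yaml summary containing a final_fbeta_score plus per-object tp/fp/fn counts and weights. The scoring itself is not part of this hunk; purely as a reference point, a weighted F-beta over aggregated counts is conventionally computed along the lines below (a sketch under that assumption, not necessarily octopi's exact formula or beta value).

# Conventional weighted F-beta from aggregated per-class counts;
# octopi's actual computation lives elsewhere in evaluate.py and may differ.
def fbeta(tp: int, fp: int, fn: int, beta: float = 1.0) -> float:
    denom = (1 + beta**2) * tp + beta**2 * fn + fp
    return (1 + beta**2) * tp / denom if denom else 0.0

def weighted_fbeta(counts: dict, weights: dict, beta: float = 1.0) -> float:
    total_weight = sum(weights.get(name, 1) for name in counts)
    if not total_weight:
        return 0.0
    score = sum(
        weights.get(name, 1) * fbeta(c["tp"], c["fp"], c["fn"], beta)
        for name, c in counts.items())
    return score / total_weight

print(weighted_fbeta({"ribosome": {"tp": 90, "fp": 10, "fn": 5}}, {"ribosome": 1}))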
octopi/processing/importers.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.processing.downsample import FourierRescale
 import copick, argparse, mrcfile, glob, os
-import octopi.processing.writers as write
 from octopi.entry_points import common
+from copick_utils.io import writers
 from tqdm import tqdm

 def from_dataportal(
@@ -57,10 +57,10 @@ def from_dataportal(

 # If we want to save the tomograms at a different voxel size, we need to rescale the tomograms
 if output_voxel_size is None:
-
+writers.tomogram(run, vol, input_voxel_size, target_tomo_type)
 else:
 vol = rescale.run(vol)
-
+writers.tomogram(run, vol, output_voxel_size, target_tomo_type)

 print(f'Downloading Complete!! Downloaded {len(root.runs)} runs')

@@ -168,7 +168,7 @@ def from_mrcs(
 voxel_size_to_write = input_voxel_size

 # Write the tomogram
-
+writers.tomogram(run, vol, voxel_size_to_write, target_tomo_type)
 print(f"Processed {len(mrc_files)} files from {mrcs_path}")
octopi/pytorch/hyper_search.py
CHANGED
@@ -1,10 +1,9 @@
-from monai.losses import FocalLoss, TverskyLoss
 from monai.metrics import ConfusionMatrixMetric
 from octopi.pytorch import trainer
 from mlflow.tracking import MlflowClient
 from octopi.models import common
-from octopi import io, losses
 import torch, mlflow, optuna, gc
+from octopi.utils import io

 class BayesianModelSearch:

@@ -207,7 +206,7 @@ class BayesianModelSearch:
 if score > best_score_so_far:
 torch.save(model_trainer.model_weights, f'{self.results_dir}/best_model.pth')
 io.save_parameters_to_yaml(self.model_builder, model_trainer, self.data_generator,
-f'{self.results_dir}/
+f'{self.results_dir}/model_config.yaml')

 def get_best_score(self, trial):
 """Retrieve the best score from the trial."""
octopi/pytorch/model_search_submitter.py
CHANGED
@@ -1,7 +1,7 @@
 from octopi.datasets import generators, multi_config_generator
+from octopi.utils import config, parsers
 from octopi.pytorch import hyper_search
 import torch, mlflow, optuna
-from octopi import utils
 from typing import List
 import pandas as pd

@@ -16,16 +16,16 @@ class ModelSearchSubmit:
 voxel_size: float,
 Nclass: int,
 model_type: str,
-
-
-
-
-
-
-
-trainRunIDs: List[str],
-validateRunIDs: List[str],
-
+best_metric: str = 'avg_f1',
+num_epochs: int = 1000,
+num_trials: int = 100,
+data_split: str = 0.8,
+random_seed: int = 42,
+val_interval: int = 10,
+tomo_batch_size: int = 15,
+trainRunIDs: List[str] = None,
+validateRunIDs: List[str] = None,
+mlflow_experiment_name: str = 'explore',
 ):
 """
 Initialize the ModelSearch class for architecture search with Optuna.
@@ -75,7 +75,7 @@ class ModelSearchSubmit:
 self.data_generator = None

 # Set random seed for reproducibility
-
+config.set_seed(self.random_seed)

 # Initialize dataset generator
 self._initialize_data_generator()
@@ -108,7 +108,7 @@ class ModelSearchSubmit:
 )

 # Split datasets into training and validation
-ratios =
+ratios = parsers.parse_data_split(self.data_split)
 self.data_generator.get_data_splits(
 trainRunIDs=self.trainRunIDs,
 validateRunIDs=self.validateRunIDs,
@@ -134,7 +134,7 @@ class ModelSearchSubmit:

 # Set up MLflow tracking
 try:
-tracking_uri =
+tracking_uri = config.mlflow_setup()
 mlflow.set_tracking_uri(tracking_uri)
 except Exception as e:
 print(f'Failed to set up MLflow tracking: {e}')
@@ -207,7 +207,7 @@ class ModelSearchSubmit:
 # Run multi-GPU optimization
 study = self.get_optuna_study()
 study.optimize(
-lambda trial: BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
+lambda trial: hyper_search.BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
 parent_run, trial,
 self.num_epochs,
 best_metric=self.best_metric,
|