octopi 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic. Click here for more details.
- octopi/__init__.py +1 -0
- octopi/datasets/cached_datset.py +1 -1
- octopi/datasets/generators.py +1 -1
- octopi/datasets/io.py +200 -0
- octopi/datasets/multi_config_generator.py +1 -1
- octopi/entry_points/common.py +9 -9
- octopi/entry_points/create_slurm_submission.py +16 -8
- octopi/entry_points/run_create_targets.py +6 -6
- octopi/entry_points/run_evaluate.py +4 -3
- octopi/entry_points/run_extract_mb_picks.py +22 -45
- octopi/entry_points/run_localize.py +37 -54
- octopi/entry_points/run_optuna.py +7 -7
- octopi/entry_points/run_segment_predict.py +4 -4
- octopi/entry_points/run_train.py +7 -8
- octopi/extract/localize.py +19 -12
- octopi/extract/membranebound_extract.py +11 -10
- octopi/extract/midpoint_extract.py +3 -3
- octopi/main.py +1 -1
- octopi/models/common.py +1 -1
- octopi/processing/create_targets_from_picks.py +11 -5
- octopi/processing/downsample.py +6 -10
- octopi/processing/evaluate.py +24 -11
- octopi/processing/importers.py +4 -4
- octopi/pytorch/hyper_search.py +2 -3
- octopi/pytorch/model_search_submitter.py +15 -15
- octopi/pytorch/segmentation.py +147 -192
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +9 -3
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +128 -0
- octopi/{utils.py → utils/parsers.py} +10 -84
- octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
- octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
- octopi/workflows.py +236 -0
- octopi-1.2.0.dist-info/METADATA +120 -0
- octopi-1.2.0.dist-info/RECORD +62 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
- octopi-1.2.0.dist-info/entry_points.txt +3 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
- octopi/io.py +0 -457
- octopi/processing/my_metrics.py +0 -26
- octopi/processing/writers.py +0 -102
- octopi-1.0.dist-info/METADATA +0 -209
- octopi-1.0.dist-info/RECORD +0 -59
- octopi-1.0.dist-info/entry_points.txt +0 -4
- /octopi/{losses.py → utils/losses.py} +0 -0
- /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
octopi/utils/config.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration utilities for MLflow setup and reproducibility.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from dotenv import load_dotenv
|
|
6
|
+
import torch, numpy as np
|
|
7
|
+
import os, random
|
|
8
|
+
import octopi
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def mlflow_setup():
    """
    Set up MLflow configuration from environment variables.

    Credentials are read from ``MLFLOW_TRACKING_USERNAME`` and
    ``MLFLOW_TRACKING_PASSWORD``. First a ``.env`` file tied to the package
    location is loaded. For a source checkout laid out as ``<repo>/src/octopi``
    this is ``<repo>/.env``; otherwise it is ``.env`` inside the package
    directory. If either credential is still missing, ``load_dotenv()`` is
    called again without an explicit path as a fallback.

    Returns:
        The value of ``MLFLOW_TRACKING_URI`` (``None`` when it is not set).

    Raises:
        ValueError: If the password or the username cannot be found.
    """
    module_root = os.path.dirname(octopi.__file__)

    # Build the .env path from path components. Plain string concatenation
    # drops the separator for non-src layouts (e.g. '.../site-packages/octopi.env'),
    # and a hard-coded 'src/octopi' never matches Windows backslash paths.
    parent_dir = os.path.dirname(module_root)
    if os.path.basename(parent_dir) == 'src':
        env_dir = os.path.dirname(parent_dir)  # <repo>/src/octopi -> <repo>
    else:
        env_dir = module_root
    dotenv_path = os.path.join(env_dir, '.env')
    load_dotenv(dotenv_path=dotenv_path)

    # MLflow setup
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Fallback: let python-dotenv locate a .env file itself
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    if username is None:
        # Assigning None into os.environ would raise an opaque TypeError.
        raise ValueError("Username is not set in environment variables or .env file!")

    print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    return os.getenv('MLFLOW_TRACKING_URI')
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def set_seed(seed):
    """
    Seed every random number generator used in training.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG, then
    the CUDA RNGs (current device and all devices) when CUDA is available.
    Finally, cuDNN is put in deterministic mode with autotuning disabled, so
    repeated runs pick the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # covers multi-GPU setups

    # Trade cuDNN speed optimizations for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
octopi/utils/io.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
File I/O utilities for YAML and JSON operations.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os, json, yaml
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Create a custom dumper that uses flow style for lists only.
|
|
9
|
+
class InlineListDumper(yaml.SafeDumper):
    """
    ``yaml.SafeDumper`` that renders lists in flow style, e.g. ``[2, 2, 1]``.

    The override only takes effect once it is registered for ``list`` via
    ``InlineListDumper.add_representer(list, InlineListDumper.represent_list)``.
    Mappings keep the dumper's normal style settings.
    """

    def represent_list(self, data):
        """Build the default sequence node for *data*, then force flow style."""
        seq_node = super().represent_list(data)
        seq_node.flow_style = True  # emit on one line: [a, b, c]
        return seq_node
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def save_parameters_yaml(params: dict, output_path: str):
    """
    Write ``params`` to ``output_path`` as YAML.

    Key order is preserved (``sort_keys=False``) and mappings use block style
    (``default_flow_style=False``). Every list is rendered inline through
    ``InlineListDumper``.
    """
    # Register the inline-list representer on the dumper before dumping.
    InlineListDumper.add_representer(list, InlineListDumper.represent_list)
    dump_options = dict(Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
    with open(output_path, 'w') as stream:
        yaml.dump(params, stream, **dump_options)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def load_yaml(path: str) -> dict:
    """
    Load a YAML file and return the contents as a dictionary.

    Args:
        path: Path to the YAML file.

    Returns:
        The document parsed with ``yaml.safe_load``. This is a dict for a
        mapping document; ``yaml.safe_load`` returns ``None`` for an empty file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    # Open directly (EAFP) instead of checking os.path.exists() first. This
    # avoids a race between the check and the open, plus a second
    # filesystem lookup.
    try:
        f = open(path, 'r')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"File not found: {path}") from err
    with f:
        return yaml.safe_load(f)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def save_results_to_json(results, filename: str):
    """
    Save training results to a JSON file.

    Values that are lists of two-element lists (e.g. ``[[epoch, value], ...]``)
    are written as single-line JSON strings via
    ``prepare_inline_results_json``. Everything is dumped with 4-space
    indentation.

    Args:
        results: Mapping of result names to values. It is not modified;
            formatting is applied to a shallow copy.
        filename: Destination path of the JSON file.
    """
    # prepare_inline_results_json rewrites values in place, so hand it a copy.
    # Otherwise the caller's lists would be replaced by strings.
    formatted = prepare_inline_results_json(dict(results))
    with open(filename, "w") as json_file:
        json.dump(formatted, json_file, indent=4)
    print(f"Training Results saved to {filename}")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def prepare_inline_results_json(results):
    """
    Serialize pair-series values of ``results`` to compact JSON strings.

    A value qualifies when it is a list whose every element is a two-element
    list (such as ``[[epoch, value], ...]``). An empty list therefore also
    qualifies and becomes ``"[]"``. The mapping is updated in place and
    returned.
    """
    def _is_pair_series(candidate):
        return isinstance(candidate, list) and all(
            isinstance(entry, list) and len(entry) == 2 for entry in candidate
        )

    for name in results:
        series = results[name]
        if _is_pair_series(series):
            # Replacing an existing key does not resize the dict, so this is
            # safe while iterating.
            results[name] = json.dumps(series)
    return results
|
|
57
|
+
|
|
58
|
+
def get_optimizer_parameters(trainer):
    """
    Collect the optimizer and loss settings of ``trainer`` into a dict.

    Always records the sample count, validation interval, first parameter
    group's learning rate, and the class names of the optimizer, metrics
    function and loss function. For the Tversky/Focal loss families, the
    loss's hyper-parameters (alpha / gamma / weight_tversky) are appended.
    """
    # Extra attributes recorded per loss class name, in output order.
    loss_hyperparams = {
        'TverskyLoss': ('alpha',),
        'FocalLoss': ('gamma',),
        'WeightedFocalTverskyLoss': ('alpha', 'gamma', 'weight_tversky'),
        'FocalTverskyLoss': ('alpha', 'gamma'),
    }

    params = {
        'my_num_samples': trainer.num_samples,
        'val_interval': trainer.val_interval,
        'lr': trainer.optimizer.param_groups[0]['lr'],
        'optimizer': trainer.optimizer.__class__.__name__,
        'metrics_function': trainer.metrics_function.__class__.__name__,
        'loss_function': trainer.loss_function.__class__.__name__,
    }

    for attr_name in loss_hyperparams.get(params['loss_function'], ()):
        params[attr_name] = getattr(trainer.loss_function, attr_name)

    return params
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
    """
    Save training parameters to a YAML file.

    The file has three top-level sections:
    - ``model``: from ``model.get_model_parameters()``.
    - ``optimizer``: from ``get_optimizer_parameters(trainer)``.
    - ``dataloader``: from ``dataloader.get_dataloader_parameters()``.
    """
    model_section = model.get_model_parameters()
    optimizer_section = get_optimizer_parameters(trainer)
    dataloader_section = dataloader.get_dataloader_parameters()

    save_parameters_yaml(
        {
            'model': model_section,
            'optimizer': optimizer_section,
            'dataloader': dataloader_section,
        },
        filename,
    )
    print(f"Training Parameters saved to {filename}")
|
|
100
|
+
|
|
101
|
+
def flatten_params(params, parent_key=''):
    """
    Flatten a nested parameter mapping into a single-level dict.

    Nested dict keys are joined with ``.``, prefixed by ``parent_key`` when it
    is non-empty. Lists become comma-separated strings (``"1, 2"``); all
    other values are kept unchanged.
    """
    def _walk(mapping, prefix):
        for name, item in mapping.items():
            dotted = f"{prefix}.{name}" if prefix else name
            if isinstance(item, dict):
                yield from _walk(item, dotted)
            elif isinstance(item, list):
                yield dotted, ', '.join(str(element) for element in item)
            else:
                yield dotted, item

    return dict(_walk(params, parent_key))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def prepare_for_inline_json(data):
    """
    Render selected list fields as bracketed strings for inline display.

    The run-ID lists under ``data['dataloader']`` and the ``channels`` /
    ``strides`` lists under ``data['model']`` are replaced in place by
    strings such as ``"['run1', 'run2']"`` (each element via ``repr``).
    Both sections must be present. The same mapping is returned.
    """
    inline_fields = (
        ('dataloader', ("trainRunIDs", "valRunIDs", "testRunIDs")),
        ('model', ('channels', 'strides')),
    )
    for section_name, field_names in inline_fields:
        section = data[section_name]
        for field in field_names:
            if field in section:
                section[field] = "[" + ", ".join(repr(v) for v in section[field]) + "]"
    return data
|
|
@@ -1,58 +1,12 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Argument parsing and configuration utilities.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import argparse, os, random
|
|
6
|
+
import torch, numpy as np
|
|
2
7
|
from typing import List, Tuple, Union
|
|
3
8
|
from dotenv import load_dotenv
|
|
4
|
-
import
|
|
5
|
-
import torch, random, os, yaml
|
|
6
|
-
from typing import List
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
##############################################################################################################################
|
|
10
|
-
|
|
11
|
-
def mlflow_setup():
|
|
12
|
-
|
|
13
|
-
module_root = os.path.dirname(octopi.__file__)
|
|
14
|
-
dotenv_path = module_root.replace('src/octopi','') + '.env'
|
|
15
|
-
load_dotenv(dotenv_path=dotenv_path)
|
|
16
|
-
|
|
17
|
-
# MLflow setup
|
|
18
|
-
username = os.getenv('MLFLOW_TRACKING_USERNAME')
|
|
19
|
-
password = os.getenv('MLFLOW_TRACKING_PASSWORD')
|
|
20
|
-
if not password or not username:
|
|
21
|
-
print("Password not found in environment, loading from .env file...")
|
|
22
|
-
load_dotenv() # Loads environment variables from a .env file
|
|
23
|
-
username = os.getenv('MLFLOW_TRACKING_USERNAME')
|
|
24
|
-
password = os.getenv('MLFLOW_TRACKING_PASSWORD')
|
|
25
|
-
|
|
26
|
-
# Check again after loading .env file
|
|
27
|
-
if not password:
|
|
28
|
-
raise ValueError("Password is not set in environment variables or .env file!")
|
|
29
|
-
else:
|
|
30
|
-
print("Password loaded successfully")
|
|
31
|
-
os.environ['MLFLOW_TRACKING_USERNAME'] = username
|
|
32
|
-
os.environ['MLFLOW_TRACKING_PASSWORD'] = password
|
|
33
|
-
|
|
34
|
-
return os.getenv('MLFLOW_TRACKING_URI')
|
|
35
|
-
|
|
36
|
-
##############################################################################################################################
|
|
37
|
-
|
|
38
|
-
def set_seed(seed):
|
|
39
|
-
# Set the seed for Python's random module
|
|
40
|
-
random.seed(seed)
|
|
41
|
-
|
|
42
|
-
# Set the seed for NumPy
|
|
43
|
-
np.random.seed(seed)
|
|
44
|
-
|
|
45
|
-
# Set the seed for PyTorch (both CPU and GPU)
|
|
46
|
-
torch.manual_seed(seed)
|
|
47
|
-
if torch.cuda.is_available():
|
|
48
|
-
torch.cuda.manual_seed(seed)
|
|
49
|
-
torch.cuda.manual_seed_all(seed) # If using multi-GPU
|
|
50
|
-
|
|
51
|
-
# Ensure reproducibility of operations by disabling certain optimizations
|
|
52
|
-
torch.backends.cudnn.deterministic = True
|
|
53
|
-
torch.backends.cudnn.benchmark = False
|
|
54
|
-
|
|
55
|
-
###############################################################################################################################
|
|
9
|
+
import octopi
|
|
56
10
|
|
|
57
11
|
def parse_list(value: str) -> List[str]:
|
|
58
12
|
"""
|
|
@@ -62,7 +16,6 @@ def parse_list(value: str) -> List[str]:
|
|
|
62
16
|
value = value.strip("[]") # Remove brackets if present
|
|
63
17
|
return [x.strip() for x in value.split(",")]
|
|
64
18
|
|
|
65
|
-
###############################################################################################################################
|
|
66
19
|
|
|
67
20
|
def parse_int_list(value: str) -> List[int]:
|
|
68
21
|
"""
|
|
@@ -71,7 +24,6 @@ def parse_int_list(value: str) -> List[int]:
|
|
|
71
24
|
"""
|
|
72
25
|
return [int(x) for x in parse_list(value)]
|
|
73
26
|
|
|
74
|
-
###############################################################################################################################
|
|
75
27
|
|
|
76
28
|
def string2bool(value: str):
|
|
77
29
|
"""
|
|
@@ -86,7 +38,6 @@ def string2bool(value: str):
|
|
|
86
38
|
else:
|
|
87
39
|
raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")
|
|
88
40
|
|
|
89
|
-
###############################################################################################################################
|
|
90
41
|
|
|
91
42
|
def parse_target(value: str) -> Tuple[str, Union[str, None], Union[str, None]]:
|
|
92
43
|
"""
|
|
@@ -130,6 +81,7 @@ def parse_seg_target(value: str) -> List[Tuple[str, Union[str, None], Union[str,
|
|
|
130
81
|
)
|
|
131
82
|
return targets
|
|
132
83
|
|
|
84
|
+
|
|
133
85
|
def parse_copick_configs(config_entries: List[str]):
|
|
134
86
|
"""
|
|
135
87
|
Parse a string representing a list of CoPick configuration file paths.
|
|
@@ -172,6 +124,7 @@ def parse_copick_configs(config_entries: List[str]):
|
|
|
172
124
|
|
|
173
125
|
return copick_configs
|
|
174
126
|
|
|
127
|
+
|
|
175
128
|
def parse_data_split(value: str) -> Tuple[float, float, float]:
|
|
176
129
|
"""
|
|
177
130
|
Parse data split ratios from string input.
|
|
@@ -208,31 +161,4 @@ def parse_data_split(value: str) -> Tuple[float, float, float]:
|
|
|
208
161
|
if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
|
|
209
162
|
raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")
|
|
210
163
|
|
|
211
|
-
return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
|
|
212
|
-
|
|
213
|
-
##############################################################################################################################
|
|
214
|
-
|
|
215
|
-
# Create a custom dumper that uses flow style for lists only.
|
|
216
|
-
class InlineListDumper(yaml.SafeDumper):
|
|
217
|
-
def represent_list(self, data):
|
|
218
|
-
node = super().represent_list(data)
|
|
219
|
-
node.flow_style = True # Use inline style for lists
|
|
220
|
-
return node
|
|
221
|
-
|
|
222
|
-
def save_parameters_yaml(params: dict, output_path: str):
|
|
223
|
-
"""
|
|
224
|
-
Save parameters to a YAML file.
|
|
225
|
-
"""
|
|
226
|
-
InlineListDumper.add_representer(list, InlineListDumper.represent_list)
|
|
227
|
-
with open(output_path, 'w') as f:
|
|
228
|
-
yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
|
|
229
|
-
|
|
230
|
-
def load_yaml(path: str) -> dict:
|
|
231
|
-
"""
|
|
232
|
-
Load a YAML file and return the contents as a dictionary.
|
|
233
|
-
"""
|
|
234
|
-
if os.path.exists(path):
|
|
235
|
-
with open(path, 'r') as f:
|
|
236
|
-
return yaml.safe_load(f)
|
|
237
|
-
else:
|
|
238
|
-
raise FileNotFoundError(f"File not found: {path}")
|
|
164
|
+
return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
|
|
@@ -7,9 +7,9 @@ class EarlyStoppingChecker:
|
|
|
7
7
|
|
|
8
8
|
def __init__(self,
|
|
9
9
|
max_nan_epochs=15,
|
|
10
|
-
plateau_patience=
|
|
11
|
-
plateau_min_delta=0.
|
|
12
|
-
stagnation_patience=
|
|
10
|
+
plateau_patience=10,
|
|
11
|
+
plateau_min_delta=0.005,
|
|
12
|
+
stagnation_patience=30,
|
|
13
13
|
convergence_window=5,
|
|
14
14
|
convergence_threshold=0.005,
|
|
15
15
|
val_interval=15,
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
+
from copick_utils.io import readers
|
|
1
2
|
import matplotlib.colors as mcolors
|
|
2
3
|
from typing import Optional, List
|
|
3
4
|
import matplotlib.pyplot as plt
|
|
4
|
-
from octopi import io
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
7
|
# Define the plotting function
|
|
@@ -76,7 +76,7 @@ def show_tomo_points(tomo, run, objects, user_id, vol_slice, session_id = None,
|
|
|
76
76
|
|
|
77
77
|
for name,_,_ in objects:
|
|
78
78
|
try:
|
|
79
|
-
coordinates =
|
|
79
|
+
coordinates = readers.coordinates(run, name=name, user_id=user_id, session_id=session_id)
|
|
80
80
|
close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
|
|
81
81
|
plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
|
|
82
82
|
except:
|
|
@@ -94,7 +94,7 @@ def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
|
|
|
94
94
|
|
|
95
95
|
for name,_,_ in objects:
|
|
96
96
|
try:
|
|
97
|
-
coordinates =
|
|
97
|
+
coordinates = readers.coordinates(run, name=name, user_id=user_id1, session_id=session_id1)
|
|
98
98
|
close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
|
|
99
99
|
plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
|
|
100
100
|
except:
|
|
@@ -106,7 +106,7 @@ def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
|
|
|
106
106
|
|
|
107
107
|
for name,_,_ in objects:
|
|
108
108
|
try:
|
|
109
|
-
coordinates =
|
|
109
|
+
coordinates = readers.coordinates(run, name=name, user_id=user_id2, session_id=session_id2)
|
|
110
110
|
close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
|
|
111
111
|
plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
|
|
112
112
|
except:
|
octopi/workflows.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
from octopi.extract.localize import process_localization
|
|
2
|
+
import octopi.processing.evaluate as octopi_evaluate
|
|
3
|
+
from monai.metrics import ConfusionMatrixMetric
|
|
4
|
+
from octopi.models import common as builder
|
|
5
|
+
from octopi.pytorch import segmentation
|
|
6
|
+
from octopi.datasets import generators
|
|
7
|
+
from octopi.pytorch import trainer
|
|
8
|
+
import multiprocess as mp
|
|
9
|
+
import copick, torch, os
|
|
10
|
+
from octopi.utils import io
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
def train(config, target_info, tomo_algorithm, voxel_size, loss_function,
          model_config=None, model_weights=None, trainRunIDs=None, validateRunIDs=None,
          model_save_path='results', best_metric='fBeta2', num_epochs=1000, use_ema=True,
          lr=1e-3, val_interval=10, num_samples=16, tomo_batch_size=15):
    """
    Train a UNet Model for Segmentation

    Args:
        config (str): Path to the Copick Config File
        target_info (list): ``[target_name, target_user_id, target_session_id]``.
            Element 0 is passed positionally to ``TrainLoaderManager``
            (presumably the target segmentation name). Elements 1 and 2 are the
            user ID and session ID.
        tomo_algorithm (str): The tomographic algorithm to use for segmentation
        voxel_size (float): The voxel size of the data
        loss_function: The loss object handed to ``ModelTrainer``. Its class
            name (and, for Tversky/Focal losses, its hyper-parameters) is
            recorded in ``model_config.yaml``.
        model_config (dict): The model configuration. When ``None``, a default
            UNet configuration is built from the Copick config.
        model_weights (str): Optional path to model weights (a state dict)
            used to initialise the model
        trainRunIDs (list): The list of run IDs to use for training
        validateRunIDs (list): The list of run IDs to use for validation
        model_save_path (str): Directory where the trainer output,
            ``model_config.yaml`` and ``results.json`` are written; created if
            missing
        best_metric (str): The metric to use for early stopping
        num_epochs (int): The number of epochs to train for
        use_ema (bool): Forwarded to ``ModelTrainer`` as ``use_ema``
        lr (float): Learning rate of the Adam optimizer
        val_interval (int): Forwarded to ``ModelTrainer.train`` as ``val_interval``
        num_samples (int): Forwarded to ``ModelTrainer.train`` as ``my_num_samples``
        tomo_batch_size (int): Forwarded to ``TrainLoaderManager`` as ``tomo_batch_size``

    Returns:
        The object returned by ``ModelTrainer.train`` (also saved to ``results.json``).
    """

    # If No Model Configuration is Provided, Use the Default Configuration
    if model_config is None:
        root = copick.from_file(config)
        # Labels index the output classes, so the class count is the largest
        # label + 1. Taking the last listed object's label is only correct
        # when the objects happen to be sorted by label.
        num_classes = max(obj.label for obj in root.pickable_objects) + 1
        model_config = {
            'architecture': 'Unet',
            'num_classes': num_classes,
            'dim_in': 80,
            'strides': [2, 2, 1],
            'channels': [48, 64, 80, 80],
            'dropout': 0.0, 'num_res_units': 1,
        }
        print('No Model Configuration Provided, Using Default Configuration')
        print(model_config)

    data_generator = generators.TrainLoaderManager(
        config,
        target_info[0],
        target_session_id=target_info[2],
        target_user_id=target_info[1],
        tomo_algorithm=tomo_algorithm,
        voxel_size=voxel_size,
        Nclasses=model_config['num_classes'],
        tomo_batch_size=tomo_batch_size)

    data_generator.get_data_splits(
        trainRunIDs=trainRunIDs,
        validateRunIDs=validateRunIDs,
        train_ratio=0.9, val_ratio=0.1, test_ratio=0.0,
        create_test_dataset=False)

    # Get the reload frequency
    data_generator.get_reload_frequency(num_epochs)

    # Monai Functions
    metrics_function = ConfusionMatrixMetric(include_background=False, metric_name=["recall", 'precision', 'f1 score'], reduction="none")

    # Build the Model
    model_builder = builder.get_model(model_config['architecture'])
    model = model_builder.build_model(model_config)

    # Load the Model Weights if Provided
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if model_weights:
        state_dict = torch.load(model_weights, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    model.to(device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr)

    # Create UNet-Trainer
    model_trainer = trainer.ModelTrainer(
        model, device, loss_function, metrics_function, optimizer,
        use_ema=use_ema
    )

    # Create the output folder up front. The config/results writes below
    # happen only after (potentially long) training, and must not fail on a
    # missing directory.
    os.makedirs(model_save_path, exist_ok=True)

    results = model_trainer.train(
        data_generator, model_save_path, max_epochs=num_epochs,
        crop_size=model_config['dim_in'], my_num_samples=num_samples,
        val_interval=val_interval, best_metric=best_metric, verbose=True
    )

    # Save parameters and results
    parameters_save_name = os.path.join(model_save_path, "model_config.yaml")
    io.save_parameters_to_yaml(model_builder, model_trainer, data_generator, parameters_save_name)

    # TODO: Write Results to Zarr or Another File Format?
    results_save_name = os.path.join(model_save_path, "results.json")
    io.save_results_to_json(results, results_save_name)

    return results
|
|
103
|
+
|
|
104
|
+
def segment(config, tomo_algorithm, voxel_size, model_weights, model_config,
            seg_info=('predict', 'octopi', '1'), use_tta=False, run_ids=None,
            num_tomos_per_batch=15):
    """
    Segment a Dataset using a Trained Model or Ensemble of Models

    Args:
        config (str): Path to the Copick Config File
        tomo_algorithm (str): The tomographic algorithm to use for segmentation
        voxel_size (float): The voxel size of the data
        model_weights (str, list): The path to the model weights or a list of paths to the model weights
        model_config (str, list): The model configuration or a list of model configurations
        seg_info (sequence): ``(segmentation_name, user_id, session_id)`` under
            which predictions are written. The default is an immutable tuple,
            so it cannot be modified across calls.
        use_tta (bool): Whether to use test time augmentation
        run_ids (list): The list of run IDs to segment. ``None`` is passed
            through to ``Predictor.batch_predict`` unchanged.
        num_tomos_per_batch (int): Forwarded to ``Predictor.batch_predict``
    """

    # Initialize the Predictor
    predict = segmentation.Predictor(
        config,
        model_config,
        model_weights,
        apply_tta=use_tta
    )

    # Run batch prediction
    predict.batch_predict(
        runIDs=run_ids,
        num_tomos_per_batch=num_tomos_per_batch,
        tomo_algorithm=tomo_algorithm,
        voxel_spacing=voxel_size,
        segmentation_name=seg_info[0],
        segmentation_user_id=seg_info[1],
        segmentation_session_id=seg_info[2]
    )
|
|
138
|
+
|
|
139
|
+
def localize(config, voxel_size, seg_info, pick_user_id, pick_session_id, n_procs=16,
             method='watershed', filter_size=10, radius_min_scale=0.4, radius_max_scale=1.0,
             run_ids=None):
    """
    Extract 3D Coordinates from the Segmentation Maps

    Args:
        config (str): Path to the Copick Config File
        voxel_size (float): The voxel size of the data
        seg_info (list): The segmentation information, forwarded unchanged to
            ``process_localization`` (presumably name, user ID and session ID
            of the segmentation)
        pick_user_id (str): The user ID of the pick
        pick_session_id (str): The session ID of the pick
        n_procs (int): Maximum number of worker processes. The count actually
            used is capped at the number of runs; ``None`` means
            ``os.cpu_count()``.
        method (str): The method to use for localization
        filter_size (int): The filter size to use for localization
        radius_min_scale (float): The minimum radius scale to use for localization
        radius_max_scale (float): The maximum radius scale to use for localization
        run_ids (list): The list of run IDs to localize; ``None`` means every
            run in the Copick project
    """

    # Load the Copick Config
    root = copick.from_file(config)

    # Get objects that can be Picked
    objects = [(obj.name, obj.label, obj.radius) for obj in root.pickable_objects if obj.is_particle]

    # Get all RunIDs
    if run_ids is None:
        run_ids = [run.name for run in root.runs]
    n_run_ids = len(run_ids)

    if n_run_ids == 0:
        # Nothing to do; avoid spinning up a process pool for zero tasks.
        print('No run IDs to localize; skipping.')
        return

    # Never start more workers than there are runs. A non-positive n_procs is
    # still rejected by Pool itself.
    requested_procs = n_procs if n_procs is not None else (os.cpu_count() or 1)
    n_workers = min(requested_procs, n_run_ids)

    # Run Localization - Main Parallelization Loop
    print(f"Using {n_workers} processes to parallelize across {n_run_ids} run IDs.")
    with mp.Pool(processes=n_workers) as pool:
        with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
            # `multiprocess` serializes with dill, so a closure over `root`
            # and the settings can be shipped to the workers.
            worker_func = lambda run_id: process_localization(
                root.get_run(run_id),
                objects,
                seg_info,
                method,
                voxel_size,
                filter_size,
                radius_min_scale,
                radius_max_scale,
                pick_session_id,
                pick_user_id
            )

            # Iterating the results also re-raises any worker exception here.
            for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
                pbar.update(1)

    print('Localization Complete!')
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def evaluate(config,
             gt_user_id, gt_session_id,
             pred_user_id, pred_session_id,
             run_ids=None, distance_threshold=0.5, save_path=None,
             object_names=None):
    """
    Evaluate the Localization on a Dataset

    Args:
        config (str): Path to the Copick Config File
        gt_user_id (str): The user ID of the ground truth
        gt_session_id (str): The session ID of the ground truth
        pred_user_id (str): The user ID of the predicted coordinates
        pred_session_id (str): The session ID of the predicted coordinates
        run_ids (list): The list of run IDs to use for evaluation
        distance_threshold (float): Passed to the evaluator as
            ``distance_threshold_scale``, i.e. a scale factor rather than an
            absolute distance (presumably relative to each object's radius;
            confirm in octopi.processing.evaluate)
        save_path (str): The path to save the evaluation results
        object_names (list): Object names to evaluate, forwarded to the
            evaluator; ``None`` keeps the evaluator's default

    Returns:
        Whatever ``evaluator.run`` returns.
    """

    print('Running Evaluation on the Following Query:')
    print(f'Distance Threshold: {distance_threshold}')
    print(f'GT User ID: {gt_user_id}, GT Session ID: {gt_session_id}')
    print(f'Pred User ID: {pred_user_id}, Pred Session ID: {pred_session_id}')
    print(f'Run IDs: {run_ids}')

    # Run Evaluation. The evaluator receives the config path directly, so the
    # Copick project is not opened here. The local name avoids shadowing the
    # builtin eval().
    evaluator = octopi_evaluate.evaluator(
        config,
        gt_user_id,
        gt_session_id,
        pred_user_id,
        pred_session_id,
        object_names=object_names
    )

    return evaluator.run(
        distance_threshold_scale=distance_threshold,
        runIDs=run_ids, save_path=save_path
    )
|