octopi 1.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of octopi might be problematic.
- octopi/__init__.py +1 -0
- octopi/datasets/cached_datset.py +1 -1
- octopi/datasets/generators.py +1 -1
- octopi/datasets/io.py +200 -0
- octopi/datasets/multi_config_generator.py +1 -1
- octopi/entry_points/common.py +5 -5
- octopi/entry_points/create_slurm_submission.py +1 -1
- octopi/entry_points/run_create_targets.py +6 -6
- octopi/entry_points/run_evaluate.py +4 -3
- octopi/entry_points/run_extract_mb_picks.py +5 -5
- octopi/entry_points/run_localize.py +8 -9
- octopi/entry_points/run_optuna.py +7 -7
- octopi/entry_points/run_segment_predict.py +4 -4
- octopi/entry_points/run_train.py +7 -8
- octopi/extract/localize.py +11 -19
- octopi/extract/membranebound_extract.py +11 -10
- octopi/extract/midpoint_extract.py +3 -3
- octopi/models/common.py +1 -1
- octopi/processing/create_targets_from_picks.py +3 -4
- octopi/processing/evaluate.py +24 -11
- octopi/processing/importers.py +4 -4
- octopi/pytorch/hyper_search.py +2 -3
- octopi/pytorch/model_search_submitter.py +4 -4
- octopi/pytorch/segmentation.py +141 -190
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +2 -2
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +128 -0
- octopi/{utils.py → utils/parsers.py} +10 -84
- octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
- octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
- octopi/workflows.py +236 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA +41 -29
- octopi-1.2.0.dist-info/RECORD +62 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
- octopi-1.2.0.dist-info/entry_points.txt +3 -0
- {octopi-1.1.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
- octopi/io.py +0 -457
- octopi/processing/my_metrics.py +0 -26
- octopi/processing/writers.py +0 -102
- octopi-1.1.dist-info/RECORD +0 -59
- octopi-1.1.dist-info/entry_points.txt +0 -4
- /octopi/{losses.py → utils/losses.py} +0 -0
- /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
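The file list above shows the main structural change in 1.2.0: the standalone helper modules move into an `octopi.utils` package, the old monolithic `octopi/io.py` and `octopi/utils.py` are split, and a new high-level `octopi/workflows.py` is added. A minimal sketch of the import paths implied by these renames (module names are taken from the list; the exported symbols are assumptions, not verified against the wheel):

```python
# Hypothetical imports reflecting the 1.2.0 layout shown in the file list above.
from octopi.utils import io, parsers, config          # new utility package
from octopi.utils import losses, submit_slurm         # moved from octopi/losses.py, octopi/submit_slurm.py
from octopi.utils.stopping_criteria import EarlyStoppingChecker
from octopi import workflows                          # new train / segment / localize / evaluate API
```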
octopi/utils/io.py
ADDED
@@ -0,0 +1,128 @@
+"""
+File I/O utilities for YAML and JSON operations.
+"""
+
+import os, json, yaml
+
+
+# Create a custom dumper that uses flow style for lists only.
+class InlineListDumper(yaml.SafeDumper):
+    def represent_list(self, data):
+        node = super().represent_list(data)
+        node.flow_style = True  # Use inline style for lists
+        return node
+
+
+def save_parameters_yaml(params: dict, output_path: str):
+    """
+    Save parameters to a YAML file.
+    """
+    InlineListDumper.add_representer(list, InlineListDumper.represent_list)
+    with open(output_path, 'w') as f:
+        yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
+
+
+def load_yaml(path: str) -> dict:
+    """
+    Load a YAML file and return the contents as a dictionary.
+    """
+    if os.path.exists(path):
+        with open(path, 'r') as f:
+            return yaml.safe_load(f)
+    else:
+        raise FileNotFoundError(f"File not found: {path}")
+
+
+def save_results_to_json(results, filename: str):
+    """
+    Save training results to a JSON file.
+    """
+    results = prepare_inline_results_json(results)
+    with open(os.path.join(filename), "w") as json_file:
+        json.dump( results, json_file, indent=4 )
+    print(f"Training Results saved to {filename}")
+
+
+def prepare_inline_results_json(results):
+    """
+    Prepare results for inline JSON formatting.
+    """
+    # Traverse the dictionary and format lists of lists as inline JSON
+    for key, value in results.items():
+        # Check if the value is a list of lists (like [[epoch, value], ...])
+        if isinstance(value, list) and all(isinstance(item, list) and len(item) == 2 for item in value):
+            # Format the list of lists as a single-line JSON string
+            results[key] = json.dumps(value)
+    return results
+
+def get_optimizer_parameters(trainer):
+    """
+    Extract optimizer parameters from a trainer object.
+    """
+    optimizer_parameters = {
+        'my_num_samples': trainer.num_samples,
+        'val_interval': trainer.val_interval,
+        'lr': trainer.optimizer.param_groups[0]['lr'],
+        'optimizer': trainer.optimizer.__class__.__name__,
+        'metrics_function': trainer.metrics_function.__class__.__name__,
+        'loss_function': trainer.loss_function.__class__.__name__,
+    }
+
+    # Log Tversky Loss Parameters
+    if trainer.loss_function.__class__.__name__ == 'TverskyLoss':
+        optimizer_parameters['alpha'] = trainer.loss_function.alpha
+    elif trainer.loss_function.__class__.__name__ == 'FocalLoss':
+        optimizer_parameters['gamma'] = trainer.loss_function.gamma
+    elif trainer.loss_function.__class__.__name__ == 'WeightedFocalTverskyLoss':
+        optimizer_parameters['alpha'] = trainer.loss_function.alpha
+        optimizer_parameters['gamma'] = trainer.loss_function.gamma
+        optimizer_parameters['weight_tversky'] = trainer.loss_function.weight_tversky
+    elif trainer.loss_function.__class__.__name__ == 'FocalTverskyLoss':
+        optimizer_parameters['alpha'] = trainer.loss_function.alpha
+        optimizer_parameters['gamma'] = trainer.loss_function.gamma
+
+    return optimizer_parameters
+
+
+def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
+    """
+    Save training parameters to a YAML file.
+    """
+
+    parameters = {
+        'model': model.get_model_parameters(),
+        'optimizer': get_optimizer_parameters(trainer),
+        'dataloader': dataloader.get_dataloader_parameters()
+    }
+
+    save_parameters_yaml(parameters, filename)
+    print(f"Training Parameters saved to {filename}")
+
+def flatten_params(params, parent_key=''):
+    """
+    Helper function to flatten and serialize nested parameters.
+    """
+    flattened = {}
+    for key, value in params.items():
+        new_key = f"{parent_key}.{key}" if parent_key else key
+        if isinstance(value, dict):
+            flattened.update(flatten_params(value, new_key))
+        elif isinstance(value, list):
+            flattened[new_key] = ', '.join(map(str, value))  # Convert list to a comma-separated string
+        else:
+            flattened[new_key] = value
+    return flattened
+
+
+def prepare_for_inline_json(data):
+    """
+    Manually join specific lists into strings for inline display.
+    """
+    for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
+        if key in data['dataloader']:
+            data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"
+
+    for key in ['channels', 'strides']:
+        if key in data['model']:
+            data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
+    return data
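A minimal usage sketch for the new `octopi.utils.io` helpers added above (the file paths and the parameter dictionary are illustrative):

```python
from octopi.utils import io

# Save a nested parameter dictionary; lists are emitted inline via InlineListDumper.
params = {'model': {'architecture': 'Unet', 'channels': [48, 64, 80, 80]}}
io.save_parameters_yaml(params, 'model_config.yaml')

# Round-trip: load_yaml raises FileNotFoundError if the path does not exist.
loaded = io.load_yaml('model_config.yaml')

# Training curves stored as [[epoch, value], ...] are collapsed to single-line JSON strings.
io.save_results_to_json({'val_fBeta2': [[10, 0.61], [20, 0.72]]}, 'results.json')
```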
octopi/{utils.py → utils/parsers.py}

@@ -1,58 +1,12 @@
-
+"""
+Argument parsing and configuration utilities.
+"""
+
+import argparse, os, random
+import torch, numpy as np
 from typing import List, Tuple, Union
 from dotenv import load_dotenv
-import
-import torch, random, os, yaml
-from typing import List
-import numpy as np
-
-##############################################################################################################################
-
-def mlflow_setup():
-
-    module_root = os.path.dirname(octopi.__file__)
-    dotenv_path = module_root.replace('src/octopi','') + '.env'
-    load_dotenv(dotenv_path=dotenv_path)
-
-    # MLflow setup
-    username = os.getenv('MLFLOW_TRACKING_USERNAME')
-    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
-    if not password or not username:
-        print("Password not found in environment, loading from .env file...")
-        load_dotenv()  # Loads environment variables from a .env file
-        username = os.getenv('MLFLOW_TRACKING_USERNAME')
-        password = os.getenv('MLFLOW_TRACKING_PASSWORD')
-
-        # Check again after loading .env file
-        if not password:
-            raise ValueError("Password is not set in environment variables or .env file!")
-        else:
-            print("Password loaded successfully")
-    os.environ['MLFLOW_TRACKING_USERNAME'] = username
-    os.environ['MLFLOW_TRACKING_PASSWORD'] = password
-
-    return os.getenv('MLFLOW_TRACKING_URI')
-
-##############################################################################################################################
-
-def set_seed(seed):
-    # Set the seed for Python's random module
-    random.seed(seed)
-
-    # Set the seed for NumPy
-    np.random.seed(seed)
-
-    # Set the seed for PyTorch (both CPU and GPU)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)  # If using multi-GPU
-
-    # Ensure reproducibility of operations by disabling certain optimizations
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-
-###############################################################################################################################
+import octopi
 
 def parse_list(value: str) -> List[str]:
     """
@@ -62,7 +16,6 @@ def parse_list(value: str) -> List[str]:
     value = value.strip("[]")  # Remove brackets if present
     return [x.strip() for x in value.split(",")]
 
-###############################################################################################################################
 
 def parse_int_list(value: str) -> List[int]:
     """
@@ -71,7 +24,6 @@ def parse_int_list(value: str) -> List[int]:
     """
     return [int(x) for x in parse_list(value)]
 
-###############################################################################################################################
 
 def string2bool(value: str):
     """
@@ -86,7 +38,6 @@ def string2bool(value: str):
     else:
         raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")
 
-###############################################################################################################################
 
 def parse_target(value: str) -> Tuple[str, Union[str, None], Union[str, None]]:
     """
@@ -130,6 +81,7 @@ def parse_seg_target(value: str) -> List[Tuple[str, Union[str, None], Union[str,
     )
     return targets
 
+
 def parse_copick_configs(config_entries: List[str]):
     """
     Parse a string representing a list of CoPick configuration file paths.
@@ -172,6 +124,7 @@ def parse_copick_configs(config_entries: List[str]):
 
     return copick_configs
 
+
 def parse_data_split(value: str) -> Tuple[float, float, float]:
     """
     Parse data split ratios from string input.
@@ -208,31 +161,4 @@ def parse_data_split(value: str) -> Tuple[float, float, float]:
     if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
         raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")
 
-    return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
-
-##############################################################################################################################
-
-# Create a custom dumper that uses flow style for lists only.
-class InlineListDumper(yaml.SafeDumper):
-    def represent_list(self, data):
-        node = super().represent_list(data)
-        node.flow_style = True  # Use inline style for lists
-        return node
-
-def save_parameters_yaml(params: dict, output_path: str):
-    """
-    Save parameters to a YAML file.
-    """
-    InlineListDumper.add_representer(list, InlineListDumper.represent_list)
-    with open(output_path, 'w') as f:
-        yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
-
-def load_yaml(path: str) -> dict:
-    """
-    Load a YAML file and return the contents as a dictionary.
-    """
-    if os.path.exists(path):
-        with open(path, 'r') as f:
-            return yaml.safe_load(f)
-    else:
-        raise FileNotFoundError(f"File not found: {path}")
+    return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
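For reference, a short sketch of the list parsers kept in the renamed module. The behaviour follows the bodies visible in the hunks above; the `octopi.utils.parsers` import path reflects the rename and is assumed:

```python
from octopi.utils import parsers

parsers.parse_list("[apo-ferritin, ribosome]")   # -> ['apo-ferritin', 'ribosome']
parsers.parse_int_list("[2, 2, 1]")              # -> [2, 2, 1]

# string2bool raises argparse.ArgumentTypeError on unrecognised values;
# the accepted truthy/falsy spellings are not shown in this diff.
parsers.string2bool("true")
```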
octopi/{stopping_criteria.py → utils/stopping_criteria.py}

@@ -7,9 +7,9 @@ class EarlyStoppingChecker:
 
     def __init__(self,
                  max_nan_epochs=15,
-                 plateau_patience=
-                 plateau_min_delta=0.
-                 stagnation_patience=
+                 plateau_patience=10,
+                 plateau_min_delta=0.005,
+                 stagnation_patience=30,
                  convergence_window=5,
                  convergence_threshold=0.005,
                  val_interval=15,
octopi/{visualization_tools.py → utils/visualization_tools.py}

@@ -1,7 +1,7 @@
+from copick_utils.io import readers
 import matplotlib.colors as mcolors
 from typing import Optional, List
 import matplotlib.pyplot as plt
-from octopi import io
 import numpy as np
 
 # Define the plotting function
@@ -76,7 +76,7 @@ def show_tomo_points(tomo, run, objects, user_id, vol_slice, session_id = None,
 
     for name,_,_ in objects:
         try:
-            coordinates =
+            coordinates = readers.coordinates(run, name=name, user_id=user_id, session_id=session_id)
             close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
             plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
         except:
@@ -94,7 +94,7 @@ def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
 
     for name,_,_ in objects:
         try:
-            coordinates =
+            coordinates = readers.coordinates(run, name=name, user_id=user_id1, session_id=session_id1)
             close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
             plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
         except:
@@ -106,7 +106,7 @@ def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
 
     for name,_,_ in objects:
         try:
-            coordinates =
+            coordinates = readers.coordinates(run, name=name, user_id=user_id2, session_id=session_id2)
             close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
             plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
         except:
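Scripts that previously read picks through the removed `octopi.io` module now go through `copick_utils`, as the hunks above show. A minimal sketch of the new call pattern (argument names copied verbatim from the diff; the project path, run ID, and object name are placeholders, and the column interpretation is inferred from the plotting code):

```python
import copick
from copick_utils.io import readers

root = copick.from_file("config.json")   # placeholder copick project config
run = root.get_run("TS_001")             # placeholder run ID

# Same call used by show_tomo_points above; the plotting code treats column 0
# as the slice (z) axis and plots columns 2 and 1 as the in-plane (x, y) axes.
coordinates = readers.coordinates(run, name="ribosome", user_id="octopi", session_id="1")
```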
octopi/workflows.py
ADDED
@@ -0,0 +1,236 @@
+from octopi.extract.localize import process_localization
+import octopi.processing.evaluate as octopi_evaluate
+from monai.metrics import ConfusionMatrixMetric
+from octopi.models import common as builder
+from octopi.pytorch import segmentation
+from octopi.datasets import generators
+from octopi.pytorch import trainer
+import multiprocess as mp
+import copick, torch, os
+from octopi.utils import io
+from tqdm import tqdm
+
+def train(config, target_info, tomo_algorithm, voxel_size, loss_function,
+          model_config = None, model_weights = None, trainRunIDs = None, validateRunIDs = None,
+          model_save_path = 'results', best_metric = 'fBeta2', num_epochs = 1000, use_ema = True):
+    """
+    Train a UNet Model for Segmentation
+
+    Args:
+        config (str): Path to the Copick Config File
+        target_info (list): List containing the target user ID, target session ID, and target algorithm
+        tomo_algorithm (str): The tomographic algorithm to use for segmentation
+        voxel_size (float): The voxel size of the data
+        loss_function (str): The loss function to use for training
+        model_config (dict): The model configuration
+        model_weights (str): The path to the model weights
+        trainRunIDs (list): The list of run IDs to use for training
+        validateRunIDs (list): The list of run IDs to use for validation
+        model_save_path (str): The path to save the model
+        best_metric (str): The metric to use for early stopping
+        num_epochs (int): The number of epochs to train for
+    """
+
+    # If No Model Configuration is Provided, Use the Default Configuration
+    if model_config is None:
+        root = copick.from_file(config)
+        model_config = {
+            'architecture': 'Unet',
+            'num_classes': root.pickable_objects[-1].label + 1,
+            'dim_in': 80,
+            'strides': [2, 2, 1],
+            'channels': [48, 64, 80, 80],
+            'dropout': 0.0, 'num_res_units': 1,
+        }
+        print('No Model Configuration Provided, Using Default Configuration')
+        print(model_config)
+
+    data_generator = generators.TrainLoaderManager(
+        config,
+        target_info[0],
+        target_session_id = target_info[2],
+        target_user_id = target_info[1],
+        tomo_algorithm = tomo_algorithm,
+        voxel_size = voxel_size,
+        Nclasses = model_config['num_classes'],
+        tomo_batch_size = 15 )
+
+    data_generator.get_data_splits(
+        trainRunIDs = trainRunIDs,
+        validateRunIDs = validateRunIDs,
+        train_ratio = 0.9, val_ratio = 0.1, test_ratio = 0.0,
+        create_test_dataset = False)
+
+    # Get the reload frequency
+    data_generator.get_reload_frequency(num_epochs)
+
+    # Monai Functions
+    metrics_function = ConfusionMatrixMetric(include_background=False, metric_name=["recall",'precision','f1 score'], reduction="none")
+
+    # Build the Model
+    model_builder = builder.get_model(model_config['architecture'])
+    model = model_builder.build_model(model_config)
+
+    # Load the Model Weights if Provided
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    if model_weights:
+        state_dict = torch.load(model_weights, map_location=device, weights_only=True)
+        model.load_state_dict(state_dict)
+    model.to(device)
+
+    # Optimizer
+    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
+
+    # Create UNet-Trainer
+    model_trainer = trainer.ModelTrainer(
+        model, device, loss_function, metrics_function, optimizer,
+        use_ema = use_ema
+    )
+
+    results = model_trainer.train(
+        data_generator, model_save_path, max_epochs=num_epochs,
+        crop_size=model_config['dim_in'], my_num_samples=16,
+        val_interval=10, best_metric=best_metric, verbose=True
+    )
+
+    # Save parameters and results
+    parameters_save_name = os.path.join(model_save_path, "model_config.yaml")
+    io.save_parameters_to_yaml(model_builder, model_trainer, data_generator, parameters_save_name)
+
+    # TODO: Write Results to Zarr or Another File Format?
+    results_save_name = os.path.join(model_save_path, "results.json")
+    io.save_results_to_json(results, results_save_name)
+
+def segment(config, tomo_algorithm, voxel_size, model_weights, model_config,
+            seg_info = ['predict', 'octopi', '1'], use_tta = False, run_ids = None):
+    """
+    Segment a Dataset using a Trained Model or Ensemble of Models
+
+    Args:
+        config (str): Path to the Copick Config File
+        tomo_algorithm (str): The tomographic algorithm to use for segmentation
+        voxel_size (float): The voxel size of the data
+        model_weights (str, list): The path to the model weights or a list of paths to the model weights
+        model_config (str, list): The model configuration or a list of model configurations
+        seg_info (list): The segmentation information
+        use_tta (bool): Whether to use test time augmentation
+        run_ids (list): The list of run IDs to use for segmentation
+    """
+
+    # Initialize the Predictor
+    predict = segmentation.Predictor(
+        config,
+        model_config,
+        model_weights,
+        apply_tta = use_tta
+    )
+
+    # Run batch prediction
+    predict.batch_predict(
+        runIDs=run_ids,
+        num_tomos_per_batch=15,
+        tomo_algorithm=tomo_algorithm,
+        voxel_spacing=voxel_size,
+        segmentation_name=seg_info[0],
+        segmentation_user_id=seg_info[1],
+        segmentation_session_id=seg_info[2]
+    )
+
+def localize(config, voxel_size, seg_info, pick_user_id, pick_session_id, n_procs = 16,
+             method = 'watershed', filter_size = 10, radius_min_scale = 0.4, radius_max_scale = 1.0,
+             run_ids = None):
+    """
+    Extract 3D Coordinates from the Segmentation Maps
+
+    Args:
+        config (str): Path to the Copick Config File
+        voxel_size (float): The voxel size of the data
+        seg_info (list): The segmentation information
+        pick_user_id (str): The user ID of the pick
+        pick_session_id (str): The session ID of the pick
+        n_procs (int): The number of processes to use for parallelization
+        method (str): The method to use for localization
+        filter_size (int): The filter size to use for localization
+        radius_min_scale (float): The minimum radius scale to use for localization
+        radius_max_scale (float): The maximum radius scale to use for localization
+        run_ids (list): The list of run IDs to use for localization
+    """
+
+    # Load the Copick Config
+    root = copick.from_file(config)
+
+    # Get objects that can be Picked
+    objects = [(obj.name, obj.label, obj.radius) for obj in root.pickable_objects if obj.is_particle]
+
+    # Get all RunIDs
+    if run_ids is None:
+        run_ids = [run.name for run in root.runs]
+    n_run_ids = len(run_ids)
+
+    # Run Localization - Main Parallelization Loop
+    print(f"Using {n_procs} processes to parallelize across {n_run_ids} run IDs.")
+    with mp.Pool(processes=n_procs) as pool:
+        with tqdm(total=n_run_ids, desc="Localization", unit="run") as pbar:
+            worker_func = lambda run_id: process_localization(
+                root.get_run(run_id),
+                objects,
+                seg_info,
+                method,
+                voxel_size,
+                filter_size,
+                radius_min_scale,
+                radius_max_scale,
+                pick_session_id,
+                pick_user_id
+            )
+
+            for _ in pool.imap_unordered(worker_func, run_ids, chunksize=1):
+                pbar.update(1)
+
+    print('Localization Complete!')
+
+
+def evaluate(config,
+             gt_user_id, gt_session_id,
+             pred_user_id, pred_session_id,
+             run_ids = None, distance_threshold = 0.5, save_path = None):
+    """
+    Evaluate the Localization on a Dataset
+
+    Args:
+        config (str): Path to the Copick Config File
+        gt_user_id (str): The user ID of the ground truth
+        gt_session_id (str): The session ID of the ground truth
+        pred_user_id (str): The user ID of the predicted coordinates
+        pred_session_id (str): The session ID of the predicted coordinates
+        run_ids (list): The list of run IDs to use for evaluation
+        distance_threshold (float): The distance threshold to use for evaluation
+        save_path (str): The path to save the evaluation results
+    """
+
+    print('Running Evaluation on the Following Query:')
+    print(f'Distance Threshold: {distance_threshold}')
+    print(f'GT User ID: {gt_user_id}, GT Session ID: {gt_session_id}')
+    print(f'Pred User ID: {pred_user_id}, Pred Session ID: {pred_session_id}')
+    print(f'Run IDs: {run_ids}')
+
+    # Load the Copick Config
+    root = copick.from_file(config)
+
+    # For Now Lets Assume Object Names are None..
+    object_names = None
+
+    # Run Evaluation
+    eval = octopi_evaluate.evaluator(
+        config,
+        gt_user_id,
+        gt_session_id,
+        pred_user_id,
+        pred_session_id,
+        object_names=object_names
+    )
+
+    eval.run(
+        distance_threshold_scale=distance_threshold,
+        runIDs=run_ids, save_path=save_path
+    )
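The new `octopi.workflows` module wraps the full pipeline in four functions. Below is a hedged end-to-end sketch built only from the signatures and docstrings above; the config path, algorithm name, loss name, weight paths, and user/session IDs are placeholders, and the `target_info` ordering is inferred from the `TrainLoaderManager` call in `train`:

```python
from octopi import workflows

config = "config.json"  # placeholder copick project config

# 1. Train a U-Net; target_info ordering inferred as [target name, user ID, session ID]
workflows.train(config,
                target_info=["targets", "octopi", "1"],
                tomo_algorithm="wbp", voxel_size=10.0,
                loss_function="TverskyLoss",
                model_save_path="results", num_epochs=1000)

# 2. Predict segmentation masks with the trained weights (placeholder paths)
workflows.segment(config, tomo_algorithm="wbp", voxel_size=10.0,
                  model_weights="results/best_model.pth",
                  model_config="results/model_config.yaml",
                  seg_info=["predict", "octopi", "1"])

# 3. Convert the masks into 3D particle picks
workflows.localize(config, voxel_size=10.0,
                   seg_info=["predict", "octopi", "1"],
                   pick_user_id="octopi", pick_session_id="1")

# 4. Score the picks against ground-truth annotations
workflows.evaluate(config, gt_user_id="curation", gt_session_id="0",
                   pred_user_id="octopi", pred_session_id="1",
                   distance_threshold=0.5, save_path="results/eval")
```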
{octopi-1.1.dist-info → octopi-1.2.0.dist-info}/METADATA

@@ -1,10 +1,13 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: octopi
-Version: 1.1
+Version: 1.2.0
 Summary: Model architecture exploration for cryoET particle picking
+Project-URL: Homepage, https://github.com/chanzuckerberg/octopi
+Project-URL: Documentation, https://chanzuckerberg.github.io/octopi/
+Project-URL: Issues, https://github.com/chanzuckerberg/octopi/issues
+Author: Jonathan Schwartz, Kevin Zhao, Daniel Ji, Utz Ermel
 License: MIT
-
-Requires-Python: >=3.9,<4.0
+License-File: LICENSE
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
@@ -12,28 +15,41 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Classifier: Topic :: Scientific/Engineering :: Image Processing
+Classifier: Topic :: Scientific/Engineering :: Image Recognition
+Requires-Python: >=3.9
 Requires-Dist: copick
+Requires-Dist: copick-utils
 Requires-Dist: ipywidgets
 Requires-Dist: kaleido
 Requires-Dist: matplotlib
-Requires-Dist: mlflow
-Requires-Dist: monai
+Requires-Dist: mlflow
+Requires-Dist: monai
 Requires-Dist: mrcfile
 Requires-Dist: multiprocess
 Requires-Dist: nibabel
-Requires-Dist: optuna
+Requires-Dist: optuna
 Requires-Dist: optuna-integration[botorch,pytorch-lightning]
 Requires-Dist: pandas
-Requires-Dist: plotly
 Requires-Dist: python-dotenv
-Requires-Dist:
-Requires-Dist: requests (>=2.25.1,<3.0.0)
-Requires-Dist: seaborn
+Requires-Dist: requests
 Requires-Dist: torch-ema
 Requires-Dist: tqdm
-
-
-
+Provides-Extra: dev
+Requires-Dist: black>=24.8.0; extra == 'dev'
+Requires-Dist: pre-commit>=3.8.0; extra == 'dev'
+Requires-Dist: pytest>=6.2.3; extra == 'dev'
+Requires-Dist: ruff>=0.6.4; extra == 'dev'
+Provides-Extra: docs
+Requires-Dist: mkdocs; extra == 'docs'
+Requires-Dist: mkdocs-awesome-pages-plugin; extra == 'docs'
+Requires-Dist: mkdocs-git-authors-plugin; extra == 'docs'
+Requires-Dist: mkdocs-git-committers-plugin-2; extra == 'docs'
+Requires-Dist: mkdocs-git-revision-date-localized-plugin; extra == 'docs'
+Requires-Dist: mkdocs-material; extra == 'docs'
+Requires-Dist: mkdocs-minify-plugin; extra == 'docs'
+Requires-Dist: mkdocs-redirects; extra == 'docs'
 Description-Content-Type: text/markdown
 
 # OCTOPI 🐙🐙🐙
@@ -63,10 +79,16 @@ Our deep learning-based pipeline streamlines the training and execution of 3D au
 
 ### Installation
 
+Octopi is available on PyPI and can be installed using pip:
 ```bash
 pip install octopi
 ```
 
+⚠️ **Note**: One of the current dependencies is not working with pip 25. To temporarily reduce the pip version, run:
+```bash
+pip install --upgrade "pip<25"
+```
+
 ### Basic Usage
 
 octopi provides two main command-line interfaces:
@@ -74,35 +96,25 @@ octopi provides two main command-line interfaces:
 ```bash
 # Main CLI for training, inference, and data processing
 octopi --help
-```
 
-The main `octopi` command provides subcommands for:
-- Data import and preprocessing
-- Training label preparation
-- Model training and exploration
-- Inference and particle localization
-
-```bash
 # HPC-specific CLI for submitting jobs to SLURM clusters
 octopi-slurm --help
 ```
 
-The `octopi-slurm` command provides utilities for:
-- Submitting training jobs to SLURM clusters
-- Managing distributed inference tasks
-- Handling batch processing on HPC systems
-
 ## 📚 Documentation
 
 For detailed documentation, tutorials, CLI and API reference, visit our [documentation](https://chanzuckerberg.github.io/octopi/).
 
 ## 🤝 Contributing
 
-
+## Code of Conduct
+
+This project adheres to the Contributor Covenant [code of conduct](https://github.com/chanzuckerberg/.github/blob/master/CODE_OF_CONDUCT.md).
+By participating, you are expected to uphold this code.
+Please report unacceptable behavior to [opensource@chanzuckerberg.com](mailto:opensource@chanzuckerberg.com).
 
 ## 🔒 Security
 
 If you believe you have found a security issue, please responsibly disclose by contacting us at security@chanzuckerberg.com.
 
 
-