octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
from octopi.entry_points import common
|
|
2
|
+
from octopi.utils import parsers
|
|
3
|
+
import rich_click as click
|
|
4
|
+
|
|
5
|
+
def save_parameters(config: tuple,
                    target_info: tuple,
                    tomo_alg: str,
                    voxel_size: float,
                    model_type: str,
                    mlflow_experiment_name: str,
                    random_seed: int,
                    num_trials: int,
                    best_metric: str,
                    num_epochs: int,
                    tomo_batch_size: int,
                    trainRunIDs: list,
                    validateRunIDs: list,
                    data_split: str,
                    output_path: str):
    """
    Print the Optuna model-search settings and write them to a YAML file.

    The settings are grouped into three sections (``input``, ``optimization``
    and ``training``), pretty-printed to stdout, and then handed to
    ``octopi.utils.io.save_parameters_yaml`` together with ``output_path``.

    Args:
        config: CoPick configuration path(s) exactly as received from the CLI.
        target_info: Target (name, user_id, session_id) information.
        tomo_alg: Tomogram reconstruction algorithm.
        voxel_size: Voxel size of the tomograms.
        model_type: Model architecture being searched.
        mlflow_experiment_name: MLflow experiment the trials are logged under.
        random_seed: Seed used for the search.
        num_trials: Number of Optuna trials.
        best_metric: Metric used to select the best model.
        num_epochs: Training epochs per trial.
        tomo_batch_size: Number of tomograms per batch.
        trainRunIDs: Explicit training run IDs, or None.
        validateRunIDs: Explicit validation run IDs, or None.
        data_split: Data split ratio string.
        output_path: Destination passed to ``save_parameters_yaml``.
    """
    import octopi.utils.io as io
    import pprint

    input_section = {
        "copick_config": config,
        "target_info": target_info,
        "tomo_algorithm": tomo_alg,
        "voxel_size": voxel_size,
    }
    optimization_section = {
        "model_type": model_type,
        "mlflow_experiment_name": mlflow_experiment_name,
        "random_seed": random_seed,
        "num_trials": num_trials,
        "best_metric": best_metric,
    }
    training_section = {
        "num_epochs": num_epochs,
        "tomo_batch_size": tomo_batch_size,
        "trainRunIDs": trainRunIDs,
        "validateRunIDs": validateRunIDs,
        "data_split": data_split,
    }
    params = {
        "input": input_section,
        "optimization": optimization_section,
        "training": training_section,
    }

    # Echo the settings so they appear in the job log as well as on disk.
    print("\nParameters for Model Architecture Search:")
    pprint.pprint(params)
    print()

    io.save_parameters_yaml(params, output_path)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@click.command('model-explore', help="Perform model architecture search with Optuna and MLflow integration")
# -- Training options --
@click.option('--random-seed', type=int, default=42,
              help="Random seed for reproducibility")
@common.train_parameters(octopi=True)
# -- Model options --
@click.option('--model-type', type=click.Choice(['Unet', 'AttentionUnet', 'MedNeXt', 'SegResNet'], case_sensitive=False),
              default='Unet',
              help="Model type to use for training")
# -- Input options --
@click.option('-split', '--data-split', type=str, default='0.8',
              help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")
@click.option('-vruns', '--validateRunIDs', type=str, default=None,
              callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
              help="List of validation run IDs, e.g., run3,run4 or [run3,run4]")
@click.option('-truns', '--trainRunIDs', type=str, default=None,
              callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
              help="List of training run IDs, e.g., run1,run2 or [run1,run2]")
@click.option('--mlflow-experiment-name', type=str, default="model-search",
              help="Name of the MLflow experiment")
@click.option('-alg', '--tomo-alg', type=str, default='wbp',
              help="Tomogram algorithm used for training")
@click.option('-tinfo', '--target-info', type=str, default="targets,octopi,1",
              callback=lambda ctx, param, value: parsers.parse_target(value),
              help="Target information, e.g., 'name' or 'name,user_id,session_id'")
@common.config_parameters(single_config=False)
def cli(config, voxel_size, target_info, tomo_alg, mlflow_experiment_name,
        trainrunids, validaterunids, data_split,
        model_type,
        num_epochs, val_interval, tomo_batch_size, best_metric, num_trials, random_seed):
    """
    Command-line wrapper for the Optuna model architecture search.

    All parsed options are forwarded unchanged to ``run_model_explore``.
    """
    run_model_explore(
        config=config,
        voxel_size=voxel_size,
        target_info=target_info,
        tomo_alg=tomo_alg,
        mlflow_experiment_name=mlflow_experiment_name,
        trainrunids=trainrunids,
        validaterunids=validaterunids,
        data_split=data_split,
        model_type=model_type,
        num_epochs=num_epochs,
        val_interval=val_interval,
        tomo_batch_size=tomo_batch_size,
        best_metric=best_metric,
        num_trials=num_trials,
        random_seed=random_seed,
    )
|
|
97
|
+
|
|
98
|
+
def run_model_explore(config, voxel_size, target_info, tomo_alg, mlflow_experiment_name,
                      trainrunids, validaterunids, data_split, model_type,
                      num_epochs, val_interval, tomo_batch_size, best_metric, num_trials, random_seed):
    """
    Launch an Optuna model architecture search for one model type.

    Steps:
      1. Resolve the CoPick configuration(s). A single config is passed through
         as-is; several are parsed with ``parsers.parse_copick_configs``.
      2. Create ``explore_results_<model_type>/`` and record the search
         settings in ``octopi.yaml`` inside it.
      3. Build a ``ModelSearchSubmit`` and run the search.

    ``target_info`` is indexed as (name, user_id, session_id).
    """
    from octopi.pytorch.model_search_submitter import ModelSearchSubmit
    import os

    copick_configs = (parsers.parse_copick_configs(config)
                      if len(config) > 1 else config[0])

    results_dir = f'explore_results_{model_type}'
    os.makedirs(results_dir, exist_ok=True)

    # The raw CLI config tuple (not the parsed form) is what gets recorded.
    save_parameters(
        config=config,
        target_info=target_info,
        tomo_alg=tomo_alg,
        voxel_size=voxel_size,
        model_type=model_type,
        mlflow_experiment_name=mlflow_experiment_name,
        random_seed=random_seed,
        num_trials=num_trials,
        best_metric=best_metric,
        num_epochs=num_epochs,
        tomo_batch_size=tomo_batch_size,
        trainRunIDs=trainrunids,
        validateRunIDs=validaterunids,
        data_split=data_split,
        output_path=f'{results_dir}/octopi.yaml',
    )

    target_name, target_user, target_session = target_info[0], target_info[1], target_info[2]
    search = ModelSearchSubmit(
        copick_config=copick_configs,
        target_name=target_name,
        target_user_id=target_user,
        target_session_id=target_session,
        tomo_algorithm=tomo_alg,
        voxel_size=voxel_size,
        model_type=model_type,
        mlflow_experiment_name=mlflow_experiment_name,
        random_seed=random_seed,
        num_epochs=num_epochs,
        num_trials=num_trials,
        trainRunIDs=trainrunids,
        validateRunIDs=validaterunids,
        tomo_batch_size=tomo_batch_size,
        best_metric=best_metric,
        val_interval=val_interval,
        data_split=data_split,
    )
    search.run_model_search()


if __name__ == "__main__":
    cli()
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
from octopi.entry_points import common
|
|
2
|
+
from typing import List, Tuple
|
|
3
|
+
import rich_click as click
|
|
4
|
+
|
|
5
|
+
def inference(
    copick_config_path: str,
    model_weights: str,
    model_config: str,
    seg_info: Tuple[str, str, str],
    voxel_size: float,
    tomo_algorithm: str,
    tomo_batch_size: int,
    run_ids: List[str],
):
    """
    Perform segmentation inference on tomograms with one model or a model ensemble.

    When both ``model_weights`` and ``model_config`` contain commas, they are
    treated as parallel lists and the run uses ensemble ("soup") segmentation.
    With more than one visible GPU, ``segmentation.MultiGPUPredictor`` is used.
    Otherwise ``segmentation.Predictor`` runs batched prediction.

    Args:
        copick_config_path (str): Path to the CoPick configuration file.
        model_weights (str): Path to a trained weights file, or a
            comma-separated list of paths for ensemble prediction.
        model_config (str): Path to a model configuration file, or a
            comma-separated list of paths matching ``model_weights``.
        seg_info (Tuple[str, str, str]): (name, user_id, session_id) under
            which the segmentation is written.
        voxel_size (float): Voxel size of the tomograms to segment.
        tomo_algorithm (str): Tomogram reconstruction algorithm to read.
        tomo_batch_size (int): Tomograms per batch. Only used by the
            single-GPU predictor.
        run_ids (List[str]): Run IDs to segment, forwarded as-is to the predictor.

    Raises:
        ValueError: If comma-separated weights and configs differ in length.
    """

    def _split_paths(value):
        # "a.pth, b.pth," -> ["a.pth", "b.pth"]: strip whitespace around each
        # entry and drop empty entries (e.g. from a trailing comma).
        return [item.strip() for item in value.split(',') if item.strip()]

    # Parse and validate the model paths before the slow torch/octopi import,
    # so that argument errors are reported immediately.
    if ',' in model_weights:
        model_weights = _split_paths(model_weights)
    if ',' in model_config:
        model_config = _split_paths(model_config)

    is_ensemble = isinstance(model_weights, list) and isinstance(model_config, list)
    if is_ensemble and len(model_weights) != len(model_config):
        raise ValueError("Number of model weights and model configs must match for ensemble prediction.")
    # NOTE(review): if only one of the two arguments is a list, no length check
    # is done and the single-model path is reported; confirm the predictors
    # support that mixed form.

    from octopi.pytorch import segmentation
    import torch

    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs available: {gpu_count}")

    if is_ensemble:
        print("\nUsing Model Ensemble (Soup) Segmentation.")
        print('Model Weights:', model_weights)
        print('Model Configs:', model_config)
    else:
        print("Using Single Model Segmentation.")

    if gpu_count > 1:
        print("Using Multi-GPU Predictor.")
        predict = segmentation.MultiGPUPredictor(
            copick_config_path,
            model_config,
            model_weights
        )

        # Run Multi-GPU inference
        predict.multi_gpu_inference(
            runIDs=run_ids,
            tomo_algorithm=tomo_algorithm,
            voxel_spacing=voxel_size,
            segmentation_name=seg_info[0],
            segmentation_user_id=seg_info[1],
            segmentation_session_id=seg_info[2],
            save=True
        )

    else:
        print("Using Single-GPU Predictor.")
        predict = segmentation.Predictor(
            copick_config_path,
            model_config,
            model_weights,
        )

        # Run batch prediction
        predict.batch_predict(
            runIDs=run_ids,
            num_tomos_per_batch=tomo_batch_size,
            tomo_algorithm=tomo_algorithm,
            voxel_spacing=voxel_size,
            segmentation_name=seg_info[0],
            segmentation_user_id=seg_info[1],
            segmentation_session_id=seg_info[2]
        )

    print("Inference completed successfully.")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@click.command('segment')
# -- Inference options --
@common.inference_parameters()
# -- Model options --
@common.inference_model_parameters()
# -- Input options --
@common.config_parameters(single_config=True)
def cli(config, voxel_size,
        model_config, model_weights,
        tomo_alg, seg_info, tomo_batch_size, run_ids):
    """
    Segment volumes using trained neural network models.

    It supports both single model inference and model ensembles
    (model soups) for improved accuracy. Multi-GPU inference is automatically enabled when
    multiple GPUs are available.

    The segmentation masks are saved as zarr arrays in your copick project, organized by
    segmentation name, user ID, and session ID for easy tracking and comparison.

    \b
    Examples:
      # Segment with a single model
      octopi segment -c config.json \\
        --model-config model.yaml --model-weights model.pth \\
        --seg-info predictions,octopi,1

    \b
      # Segment with model ensemble (comma-separated)
      octopi segment -c config.json \\
        --model-config model1.yaml,model2.yaml \\
        --model-weights model1.pth,model2.pth \\
        --seg-info ensemble,octopi,1

    \b
      # Segment specific runs only
      octopi segment -c config.json \\
        --model-config model.yaml --model-weights model.pth \\
        --run-ids TS_001,TS_002,TS_003 \\
        --tomo-batch-size 10
    """
    # The docstring above is the click help text, so it stays user-facing.
    # Positions 1 (user ID) and 2 (session ID) of seg_info fall back to
    # "octopi" and "1" when they were not supplied.
    seg_info = list(seg_info)
    for position, fallback in ((1, "octopi"), (2, "1")):
        if seg_info[position] is None:
            seg_info[position] = fallback

    inference(
        copick_config_path=config,
        model_config=model_config,
        model_weights=model_weights,
        seg_info=seg_info,
        voxel_size=voxel_size,
        tomo_algorithm=tomo_alg,
        tomo_batch_size=tomo_batch_size,
        run_ids=run_ids,
    )


if __name__ == "__main__":
    cli()
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
from typing import List, Optional, Tuple
|
|
2
|
+
from octopi.entry_points import common
|
|
3
|
+
from octopi.utils import parsers
|
|
4
|
+
from octopi import cli_context
|
|
5
|
+
import rich_click as click
|
|
6
|
+
|
|
7
|
+
# Configure rich-click help rendering for the commands in this module.
# NOTE(review): these are assignments to attributes of the shared
# `click.rich_click` module, made at import time, so they also apply to any
# other rich_click command in the same process. Confirm this is intended.
click.rich_click.USE_RICH_MARKUP = True
click.rich_click.SHOW_ARGUMENTS = True
click.rich_click.GROUP_ARGUMENTS_OPTIONS = True
|
|
11
|
+
|
|
12
|
+
def train_model(
    copick_config_path: str,
    target_info: Tuple[str, str, str],
    tomo_algorithm: str = 'wbp',
    voxel_size: float = 10,
    trainRunIDs: Optional[List[str]] = None,
    validateRunIDs: Optional[List[str]] = None,
    model_config: Optional[dict] = None,
    model_weights: Optional[str] = None,
    model_save_path: str = 'results',
    num_tomo_crops: int = 16,
    tomo_batch_size: int = 15,
    lr: float = 1e-3,
    tversky_alpha: float = 0.5,
    num_epochs: int = 100,
    val_interval: int = 5,
    best_metric: str = 'avg_f1',
    data_split: str = '0.8'
):
    """
    Train a segmentation model on CoPick data with a Tversky loss.

    Args:
        copick_config_path: Path to a single CoPick configuration file, or a
            dict of configurations. ``run_train`` passes the result of
            ``parsers.parse_copick_configs`` here when several configs are
            given, and that triggers ``MultiConfigTrainLoaderManager``.
        target_info: (name, user_id, session_id) of the training targets.
        tomo_algorithm: Tomogram reconstruction algorithm to load.
        voxel_size: Voxel size of the tomograms.
        trainRunIDs: Optional explicit training run IDs for ``get_data_splits``.
        validateRunIDs: Optional explicit validation run IDs for ``get_data_splits``.
        model_config: Model configuration dict, e.g. from ``get_model_config``
            or a loaded YAML file. Required. A copy with ``num_classes`` set
            from the data generator is passed to training; the caller's dict
            is not modified.
        model_weights: Optional weights file to initialize from (fine-tuning).
        model_save_path: Output location forwarded to ``train``.
        num_tomo_crops: Accepted but currently unused (see NOTE in the body).
        tomo_batch_size: Forwarded to the data generator.
        lr: Initial learning rate, forwarded as ``lr0``.
        tversky_alpha: Tversky alpha; beta is set to ``1 - alpha``.
        num_epochs: Number of training epochs.
        val_interval: Accepted but currently unused (see NOTE in the body).
        best_metric: Metric name forwarded to ``train``.
        data_split: Split-ratio string parsed by ``parsers.parse_data_split``.

    Raises:
        ValueError: If ``model_config`` is None.
    """
    # Check before the heavy imports and data loading below, so a missing
    # model configuration fails fast instead of after the data generator is
    # built (previously this surfaced as a TypeError on item assignment).
    if model_config is None:
        raise ValueError(
            "model_config is required: pass a model configuration dict "
            "(e.g. the output of get_model_config or a loaded YAML file)."
        )

    import matplotlib
    # Force a headless-safe backend everywhere (must be BEFORE pyplot import)
    matplotlib.use("Agg", force=True)

    from octopi.datasets import generators, multi_config_generator
    from monai.losses import TverskyLoss
    from octopi.utils import parsers
    from octopi.workflows import train

    # Initialize the data generator to manage training and validation datasets
    print(f'Training with {copick_config_path}\n')

    # Keyword arguments shared by both data-generator flavours.
    loader_kwargs = dict(
        target_session_id=target_info[2],
        target_user_id=target_info[1],
        tomo_algorithm=tomo_algorithm,
        voxel_size=voxel_size,
        tomo_batch_size=tomo_batch_size,
    )
    if isinstance(copick_config_path, dict):
        # Multi-config training
        data_generator = multi_config_generator.MultiConfigTrainLoaderManager(
            copick_config_path, target_info[0], **loader_kwargs)
    else:
        # Single-config training
        data_generator = generators.TrainLoaderManager(
            copick_config_path, target_info[0], **loader_kwargs)

    # Get the data splits
    ratios = parsers.parse_data_split(data_split)
    data_generator.get_data_splits(
        trainRunIDs=trainRunIDs,
        validateRunIDs=validateRunIDs,
        train_ratio=ratios[0], val_ratio=ratios[1], test_ratio=ratios[2],
        create_test_dataset=False)

    # Get the reload frequency
    data_generator.get_reload_frequency(num_epochs)

    # Work on a copy so the caller's dict is not mutated.
    model_config = {**model_config, 'num_classes': data_generator.Nclasses}

    # Monai Functions
    alpha = tversky_alpha
    beta = 1 - alpha
    loss_function = TverskyLoss(include_background=True, to_onehot_y=True, softmax=True, alpha=alpha, beta=beta)

    # NOTE(review): num_tomo_crops and val_interval are accepted (and exposed
    # on the CLI) but are not forwarded to train(). Confirm whether
    # octopi.workflows.train accepts them, and pass them through if so.
    train(
        data_generator, loss_function,
        model_config=model_config, model_weights=model_weights,
        best_metric=best_metric, num_epochs=num_epochs,
        model_save_path=model_save_path, lr0=lr
    )
|
|
91
|
+
|
|
92
|
+
def get_model_config(channels, strides, res_units, dim_in, dropout=0.1):
    """
    Build a default UNet model configuration dict.

    Used when no model configuration file is supplied on the command line.

    Args:
        channels: Channel sizes for each layer (stored under ``'channels'``).
        strides: Strides for the layers (stored under ``'strides'``).
        res_units: Number of residual units (stored under ``'num_res_units'``).
        dim_in: Input dimension value (stored under ``'dim_in'``).
        dropout: Dropout probability. Defaults to 0.1, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        dict: Configuration with ``'architecture'`` fixed to ``'Unet'``.
    """
    model_config = {
        'architecture': 'Unet',
        'channels': channels,
        'strides': strides,
        'num_res_units': res_units,
        'dropout': dropout,
        'dim_in': dim_in
    }
    return model_config
|
|
105
|
+
|
|
106
|
+
@click.command('train', help="Train 3D CNN U-Net models")
# -- Training options (decorators are applied bottom-up) --
@common.train_parameters(octopi=False)
# -- UNet model options --
@common.model_parameters(octopi=False)
# -- Fine-tuning options --
@click.option('-mw', '--model-weights', type=click.Path(exists=True), default=None,
              help="Path to the model weights file (typically used for fine-tuning)")
@click.option('-mc', '--model-config', type=click.Path(exists=True), default=None,
              help="Path to the model configuration file (typically used for fine-tuning)")
# -- Input options --
@click.option('-split', '--data-split', type=str, default='0.8',
              help="Data split ratios. Either a single value (e.g., '0.8' for 80/20/0 split) or two comma-separated values (e.g., '0.7,0.1' for 70/10/20 split)")
@click.option('-vruns', "--validateRunIDs", type=str, default=None,
              callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
              help="List of validation run IDs, e.g., run4,run5,run6")
@click.option('-truns', "--trainRunIDs", type=str, default=None,
              callback=lambda ctx, param, value: parsers.parse_list(value) if value else None,
              help="List of training run IDs, e.g., run1,run2,run3")
@click.option('-alg', "--tomo-alg", type=str, default='wbp',
              help="Tomogram algorithm used for training")
@click.option('-tinfo', "--target-info", type=str, default="targets,octopi,1",
              callback=lambda ctx, param, value: parsers.parse_target(value),
              help="Target information, e.g., 'name' or 'name,user_id,session_id'. Default is 'targets,octopi,1'.")
@common.config_parameters(single_config=False)
def cli(config, voxel_size, target_info, tomo_alg, trainrunids, validaterunids, data_split,
        model_config, model_weights,
        channels, strides, res_units, dim_in,
        num_epochs, val_interval, tomo_batch_size, best_metric,
        num_tomo_crops, lr, tversky_alpha, model_save_path):
    """
    Command-line wrapper for model training.

    Every parsed option is forwarded unchanged to ``run_train``.
    """
    run_train(
        config=config, voxel_size=voxel_size, target_info=target_info,
        tomo_alg=tomo_alg, trainrunids=trainrunids, validaterunids=validaterunids,
        data_split=data_split,
        model_config=model_config, model_weights=model_weights,
        channels=channels, strides=strides, res_units=res_units, dim_in=dim_in,
        num_epochs=num_epochs, val_interval=val_interval,
        tomo_batch_size=tomo_batch_size, best_metric=best_metric,
        num_tomo_crops=num_tomo_crops, lr=lr, tversky_alpha=tversky_alpha,
        model_save_path=model_save_path,
    )
|
|
145
|
+
|
|
146
|
+
def run_train(config, voxel_size, target_info, tomo_alg, trainrunids, validaterunids, data_split,
              model_config, model_weights,
              channels, strides, res_units, dim_in,
              num_epochs, val_interval, tomo_batch_size, best_metric,
              num_tomo_crops, lr, tversky_alpha, model_save_path):
    """
    Resolve the CoPick and model configurations, then start training.

    - Several CoPick configs are parsed with ``parsers.parse_copick_configs``;
      a single config is passed through as-is.
    - If ``model_config`` (a file path) is given, it is loaded with
      ``io.load_yaml``. Otherwise a UNet config is built from ``channels``,
      ``strides``, ``res_units`` and ``dim_in`` via ``get_model_config``.
    - Everything is then handed to ``train_model``.
    """
    import octopi.utils.io as io

    copick_configs = config[0] if len(config) <= 1 else parsers.parse_copick_configs(config)

    # An explicit model configuration file takes precedence over the CLI options.
    model_config_dict = (io.load_yaml(model_config) if model_config
                         else get_model_config(channels, strides, res_units, dim_in))

    train_model(
        copick_config_path=copick_configs,
        target_info=target_info,
        tomo_algorithm=tomo_alg,
        voxel_size=voxel_size,
        model_config=model_config_dict,
        model_weights=model_weights,
        model_save_path=model_save_path,
        num_tomo_crops=num_tomo_crops,
        tomo_batch_size=tomo_batch_size,
        lr=lr,
        tversky_alpha=tversky_alpha,
        num_epochs=num_epochs,
        val_interval=val_interval,
        best_metric=best_metric,
        trainRunIDs=trainrunids,
        validateRunIDs=validaterunids,
        data_split=data_split,
    )


if __name__ == '__main__':
    cli()
|
|
File without changes
|