octopi 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/pytorch/hyper_search.py
@@ -0,0 +1,244 @@
+from monai.metrics import ConfusionMatrixMetric
+from mlflow.tracking import MlflowClient
+from octopi.pytorch import trainer
+from octopi.models import common
+import torch, mlflow, optuna, gc
+from octopi.utils import io
+
+class BayesianModelSearch:
+
+    def __init__(self, data_generator, model_type="Unet", parent_run_id=None, parent_run_name=None):
+        """
+        Class to handle model creation, training, and optimization.
+
+        Args:
+            data_generator (object): Data generator object containing dataset properties.
+            model_type (str): Type of model to build ("UNet", "AttentionUnet").
+        """
+        self.data_generator = data_generator
+        self.Nclasses = data_generator.Nclasses
+        self.device = None
+        self.model_type = model_type
+        self.model = None
+        self.loss_function = None
+        self.metrics_function = None
+        self.sampling = None
+        self.parent_run_id = parent_run_id
+
+        # Define results directory path
+        self.results_dir = f'explore_results_{self.model_type}'
+
+    def my_build_model(self, trial):
+        """Builds and initializes a model based on Optuna-suggested parameters."""
+
+        # Build the model
+        self.model_builder = common.get_model(self.model_type)
+        self.model_builder.bayesian_search(trial, self.Nclasses)
+        self.model = self.model_builder.model.to(self.device)
+        self.config = self.model_builder.config
+
+        # Define loss function
+        self.loss_function = common.get_loss_function(trial)
+
+        # Define metrics
+        self.metrics_function = ConfusionMatrixMetric(
+            include_background=False,
+            metric_name=["recall", "precision", "f1 score"],
+            reduction="none"
+        )
+
+        # Sample crop size and num_samples
+        self.sampling = {
+            'crop_size': trial.suggest_int("crop_size", 48, 192, step=16),
+            'num_samples': 8
+        }
+        self.config['dim_in'] = self.sampling['crop_size']
+
+    def _define_optimizer(self, trial):
+        # Define optimizer
+        # lr0 = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
+        # wd = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
+        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
+
+    def _train_model(self, trial, model_trainer, epochs, val_interval, crop_size, num_samples, best_metric):
+        """Handles model training and error handling."""
+        try:
+            results = model_trainer.train(
+                self.data_generator,
+                model_save_path=None,
+                crop_size=crop_size,
+                max_epochs=epochs,
+                val_interval=val_interval,
+                my_num_samples=num_samples,
+                best_metric=best_metric,
+                use_mlflow=True,
+                verbose=False,
+                trial=trial
+            )
+            return results['best_metric']
+
+        except torch.cuda.OutOfMemoryError:
+            print(f"[Trial Failed] OOM Error for model={model_trainer.model}, crop_size={crop_size}, num_samples={num_samples}")
+            trial.set_user_attr("out_of_memory", True)
+            raise optuna.TrialPruned()
+
+        except Exception as e:
+            print(f"[Trial Failed] Unexpected error: {e}")
+            trial.set_user_attr("error", str(e))
+            raise optuna.TrialPruned()
+
+    def objective(self, trial, epochs, device, val_interval=15, best_metric="avg_f1"):
+        """Runs the full training process for a given trial."""
+
+        # Set device
+        self.device = device
+
+        # Set a unique run name for each trial
+        trial_num = f"trial_{trial.number}"
+
+        # Start MLflow run
+        with mlflow.start_run(run_name=trial_num, nested=True):
+
+            # Build model
+            self.my_build_model(trial)
+
+            # Create trainer
+            self._define_optimizer(trial)
+            model_trainer = trainer.ModelTrainer(self.model, self.device, self.loss_function, self.metrics_function, self.optimizer)
+
+            # Train model and evaluate score
+            score = self._train_model(
+                trial, model_trainer, epochs, val_interval,
+                self.sampling['crop_size'], self.sampling['num_samples'],
+                best_metric)
+
+            # Log parameters and metrics
+            params = {
+                'model': self.model_builder.get_model_parameters(),
+                'optimizer': io.get_optimizer_parameters(model_trainer)
+            }
+            model_trainer.my_log_params(io.flatten_params(params))
+
+            # Explicitly set the parent run ID
+            mlflow.log_param("parent_run_id", self.parent_run_id)
+            mlflow.log_param("parent_run_name", self.parent_run_name)
+
+            # Save best model
+            self._save_best_model(trial, model_trainer, score)
+
+            # Cleanup
+            self.cleanup(model_trainer)
+
+            return score
+
+    def _setup_parallel_trial_run(self, trial, parent_run=None, gpu_count=1):
+        """Set up parallel MLflow runs and assign GPU for the trial."""
+
+        trial_num = f"trial_{trial.number}"
+
+        # Create a child run under the parent MLflow experiment
+        mlflow_client = MlflowClient()
+        self.trial_run = mlflow_client.create_run(
+            experiment_id=mlflow_client.get_run(parent_run.info.run_id).info.experiment_id,
+            tags={"mlflow.parentRunId": parent_run.info.run_id},
+            run_name=trial_num
+        )
+        target_run_id = self.trial_run.info.run_id
+        print(f"Logging trial {trial.number} data to MLflow run: {target_run_id}")
+
+        # Assign GPU device
+        if gpu_count > 1:
+            gpu_id = trial.number % gpu_count
+            device = torch.device(f"cuda:{gpu_id}")
+            torch.cuda.set_device(device)
+        else:
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        return device, mlflow_client, target_run_id
+
+    def multi_gpu_objective(self, parent_run, trial, epochs, val_interval=5, best_metric="avg_f1", gpu_count=1):
+        """
+        Trains model on multiple GPUs using a parent MLflow run.
+        """
+        self.device, self.client, self.target_run_id = self._setup_parallel_trial_run(trial, parent_run, gpu_count)
+
+        # Build model + trainer
+        self.my_build_model(trial)
+        self._define_optimizer(trial)
+        model_trainer = trainer.ModelTrainer(
+            self.model, self.device, self.loss_function,
+            self.metrics_function, self.optimizer
+        )
+        model_trainer.set_parallel_mlflow(self.client, self.target_run_id)
+
+        # Train Model, with error handling
+        score = None
+        run_status = "FAILED" # default; overwritten on success/prune
+        try:
+            score = self._train_model(
+                trial, model_trainer, epochs, val_interval,
+                self.sampling['crop_size'], self.sampling['num_samples'],
+                best_metric)
+            # Save best model and mark run as finished
+            run_status = "FINISHED"
+            self._save_best_model(trial, model_trainer, score)
+            return score
+        except optuna.TrialPruned:
+            run_status = "KILLED" # communicates early stop
+            raise
+        except Exception as e:
+            run_status = "FAILED"
+            print(f"[Trial Failed] Unexpected error: {e}")
+            trial.set_user_attr("error", str(e))
+            raise optuna.TrialPruned()
+        finally:
+            self.cleanup(model_trainer)
+            model_trainer.my_log_params({"parent_run_name": parent_run.info.run_name})
+            try:
+                if self.client is not None and self.target_run_id is not None:
+                    self.client.set_terminated(self.target_run_id, status=run_status)
+            except Exception as e:
+                print(f"[Cleanup Failed] Unexpected error: {e}")
+
+    def _save_best_model(self, trial, model_trainer, score):
+        """Saves the best model if it improves upon previous scores."""
+        best_score_so_far = self.get_best_score(trial)
+        if score > best_score_so_far:
+            torch.save(model_trainer.model_weights, f'{self.results_dir}/best_model.pth')
+            io.save_parameters_to_yaml(self.model_builder, model_trainer, self.data_generator,
+                                       f'{self.results_dir}/model_config.yaml')
+
+    def get_best_score(self, trial):
+        """Retrieve the best score from the trial."""
+        try:
+            return trial.study.best_value
+        except ValueError:
+            return -float('inf')
+
+    def cleanup(self, model_trainer):
+        """Handles cleanup of resources."""
+
+        # Log training parameters
+        params = {
+            'model': self.model_builder.get_model_parameters(),
+            'optimizer': io.get_optimizer_parameters(model_trainer)
+        }
+        model_trainer.my_log_params(io.flatten_params(params))
+
+        # Delete the trainer and optimizer objects
+        del model_trainer, self.optimizer
+
+        # If the model object holds GPU memory, delete it explicitly and set it to None
+        if hasattr(self, "model"):
+            del self.model
+            self.model = None
+
+        # Optional: If your model_builder or other objects hold GPU references, delete them too
+        if hasattr(self, "model_builder"):
+            del self.model_builder
+            self.model_builder = None
+
+        # Clear the CUDA cache and force garbage collection
+        torch.cuda.empty_cache()
+        gc.collect()
+
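For reference, the following is a minimal, illustrative sketch (not part of the package contents above) of how the BayesianModelSearch class from hyper_search.py could be driven as a single-GPU Optuna study. It assumes an already-constructed octopi data generator (the object passed as data_generator above, exposing Nclasses and the loader interface trainer.ModelTrainer expects); the epoch and trial counts are arbitrary placeholders. model_search_submitter.py, shown next, wraps the same calls with RDB storage, pruning, and MLflow bookkeeping.

import os
import torch, mlflow, optuna
from octopi.pytorch import hyper_search

def run_single_gpu_search(data_generator, num_epochs=100, num_trials=20):
    # Illustrative only: placeholder epoch/trial counts, no RDB storage or pruner.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    search = hyper_search.BayesianModelSearch(data_generator, model_type="Unet")

    # The class writes its best checkpoint/config into explore_results_<model_type>/.
    os.makedirs(search.results_dir, exist_ok=True)

    with mlflow.start_run() as parent_run:
        # objective() logs both of these, mirroring _single_gpu_optuna in the next file.
        search.parent_run_id = parent_run.info.run_id
        search.parent_run_name = parent_run.info.run_name

        study = optuna.create_study(direction="maximize")
        study.optimize(
            lambda trial: search.objective(
                trial, num_epochs, device,
                val_interval=15, best_metric="avg_f1"),
            n_trials=num_trials,
        )
    return study.best_params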
octopi/pytorch/model_search_submitter.py
@@ -0,0 +1,291 @@
+from octopi.datasets import generators, multi_config_generator
+from octopi.utils import config, parsers
+from octopi.pytorch import hyper_search
+import torch, mlflow, optuna
+from typing import List
+import pandas as pd
+
+class ModelSearchSubmit:
+    def __init__(
+        self,
+        copick_config: str,
+        target_name: str,
+        target_user_id: str,
+        target_session_id: str,
+        tomo_algorithm: str,
+        voxel_size: float,
+        model_type: str,
+        best_metric: str = 'avg_f1',
+        num_epochs: int = 1000,
+        num_trials: int = 100,
+        data_split: str = 0.8,
+        random_seed: int = 42,
+        val_interval: int = 10,
+        tomo_batch_size: int = 15,
+        trainRunIDs: List[str] = None,
+        validateRunIDs: List[str] = None,
+        mlflow_experiment_name: str = 'explore',
+    ):
+        """
+        Initialize the ModelSearch class for architecture search with Optuna.
+
+        Args:
+            copick_config (str or dict): Path to the CoPick configuration file or a dictionary for multi-config training.
+            target_name (str): Name of the target for segmentation.
+            target_user_id (str): Optional user ID for tracking.
+            target_session_id (str): Optional session ID for tracking.
+            tomo_algorithm (str): Tomogram algorithm to use.
+            voxel_size (float): Voxel size for tomograms.
+            Nclass (int): Number of prediction classes.
+            model_type (str): Type of model to use.
+            mlflow_experiment_name (str): MLflow experiment name.
+            random_seed (int): Seed for reproducibility.
+            num_epochs (int): Number of epochs per trial.
+            num_trials (int): Number of trials for hyperparameter optimization.
+            tomo_batch_size (int): Batch size for tomogram loading.
+            best_metric (str): Metric to optimize.
+            val_interval (int): Validation interval.
+            trainRunIDs (List[str]): List of training run IDs.
+            validateRunIDs (List[str]): List of validation run IDs.
+            data_split (str): Data split ratios.
+        """
+
+        # Input parameters
+        self.copick_config = copick_config
+        self.target_name = target_name
+        self.target_user_id = target_user_id
+        self.target_session_id = target_session_id
+        self.tomo_algorithm = tomo_algorithm
+        self.voxel_size = voxel_size
+        self.model_type = model_type
+        self.mlflow_experiment_name = mlflow_experiment_name
+        self.random_seed = random_seed
+        self.num_epochs = num_epochs
+        self.num_trials = num_trials
+        self.tomo_batch_size = tomo_batch_size
+        self.best_metric = best_metric
+        self.val_interval = val_interval
+        self.trainRunIDs = trainRunIDs
+        self.validateRunIDs = validateRunIDs
+        self.data_split = data_split
+
+        # Data generator - will be initialized in _initialize_data_generator()
+        self.data_generator = None
+
+        # Set random seed for reproducibility
+        config.set_seed(self.random_seed)
+
+        # Initialize dataset generator
+        self._initialize_data_generator()
+
+    def _initialize_data_generator(self):
+        """Initializes the data generator for training and validation datasets."""
+        self._print_input_configs()
+
+        if isinstance(self.copick_config, dict):
+            self.data_generator = multi_config_generator.MultiConfigTrainLoaderManager(
+                self.copick_config,
+                self.target_name,
+                target_session_id=self.target_session_id,
+                target_user_id=self.target_user_id,
+                tomo_algorithm=self.tomo_algorithm,
+                voxel_size=self.voxel_size,
+                tomo_batch_size=self.tomo_batch_size
+            )
+        else:
+            self.data_generator = generators.TrainLoaderManager(
+                self.copick_config,
+                self.target_name,
+                target_session_id=self.target_session_id,
+                target_user_id=self.target_user_id,
+                tomo_algorithm=self.tomo_algorithm,
+                voxel_size=self.voxel_size,
+                tomo_batch_size=self.tomo_batch_size
+            )
+
+        # Split datasets into training and validation
+        ratios = parsers.parse_data_split(self.data_split)
+        self.data_generator.get_data_splits(
+            trainRunIDs=self.trainRunIDs,
+            validateRunIDs=self.validateRunIDs,
+            train_ratio = ratios[0], val_ratio = ratios[1], test_ratio = ratios[2],
+            create_test_dataset = False
+        )
+
+        # Get the reload frequency
+        self.data_generator.get_reload_frequency(self.num_epochs)
+        self.Nclass = self.data_generator.Nclasses
+
+    def _print_input_configs(self):
+        """Prints training configuration for debugging purposes."""
+        print(f'\nTraining with:')
+        if isinstance(self.copick_config, dict):
+            for session, config in self.copick_config.items():
+                print(f' {session}: {config}')
+        else:
+            print(f' {self.copick_config}')
+        print()
+
+    def run_model_search(self):
+        """Performs model architecture search using Optuna and MLflow."""
+
+        # Set up MLflow tracking
+        try:
+            tracking_uri = config.mlflow_setup()
+            mlflow.set_tracking_uri(tracking_uri)
+        except Exception as e:
+            print(f'Failed to set up MLflow tracking: {e}')
+            pass
+
+        mlflow.set_experiment(self.mlflow_experiment_name)
+
+        # Create a storage object with heartbeat configuration
+        storage_url = f"sqlite:///explore_results_{self.model_type}/trials.db"
+        self.storage = optuna.storages.RDBStorage(
+            url=storage_url,
+            heartbeat_interval=60, # Record heartbeat every minute
+            grace_period=600, # 10 minutes grace period
+            failed_trial_callback=optuna.storages.RetryFailedTrialCallback(max_retry=1)
+        )
+
+        # Detect GPU availability
+        gpu_count = torch.cuda.device_count()
+        print(f'Running Architecture Search Over {gpu_count} GPUs\n')
+
+        # Initialize model search object
+        if gpu_count > 1:
+            self._multi_gpu_optuna(gpu_count)
+        else:
+            model_search = hyper_search.BayesianModelSearch(self.data_generator, self.model_type)
+            self._single_gpu_optuna(model_search)
+
+    def _single_gpu_optuna(self, model_search):
+        """Runs Optuna optimization on a single GPU."""
+
+        with mlflow.start_run(nested=False) as parent_run:
+
+            model_search.parent_run_id = parent_run.info.run_id
+            model_search.parent_run_name = parent_run.info.run_name
+
+            # Log the experiment parameters
+            mlflow.log_params({"random_seed": self.random_seed})
+            mlflow.log_params(self.data_generator.get_dataloader_parameters())
+            mlflow.log_params({"parent_run_name": parent_run.info.run_name})
+
+            # Determine device
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+            # Create the study and run the optimization
+            study = self.get_optuna_study()
+            study.optimize(
+                lambda trial: model_search.objective(
+                    trial, self.num_epochs, device,
+                    val_interval=self.val_interval,
+                    best_metric=self.best_metric,
+                ),
+                n_trials=self.num_trials
+            )
+
+            # Save contour plot
+            self.save_contour_plot_as_png(study)
+
+            print(f"Best trial: {study.best_trial.value}")
+            print(f"Best params: {study.best_params}")
+
+    def _multi_gpu_optuna(self, gpu_count):
+        """Runs Optuna optimization on multiple GPUs."""
+        with mlflow.start_run() as parent_run:
+
+            # Log the experiment parameters
+            mlflow.log_params({"random_seed": self.random_seed})
+            mlflow.log_params(self.data_generator.get_dataloader_parameters())
+            mlflow.log_params({"parent_run_name": parent_run.info.run_name})
+
+            # Run multi-GPU optimization
+            study = self.get_optuna_study()
+            study.optimize(
+                lambda trial: hyper_search.BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
+                    parent_run, trial,
+                    self.num_epochs,
+                    best_metric=self.best_metric,
+                    val_interval=self.val_interval,
+                    gpu_count=gpu_count
+                ),
+                n_trials=self.num_trials,
+                n_jobs=gpu_count
+            )
+
+            # Save contour Plot
+            self.save_contour_plot_as_png(study)
+
+            print(f"Best trial: {study.best_trial.value}")
+            print(f"Best params: {study.best_params}")
+
+    def _get_optuna_sampler(self):
+        """Returns Optuna's TPE sampler with default settings."""
+        return optuna.samplers.TPESampler(
+            n_startup_trials=10,
+            n_ei_candidates=24,
+            multivariate=True
+        )
+        # return optuna.samplers.BoTorchSampler(
+        #     n_startup_trials=10,
+        #     multivariate=True
+        # )
+
+    def get_optuna_study(self):
+        """Returns the Optuna study object."""
+        return optuna.create_study(
+            storage=self.storage,
+            direction="maximize",
+            sampler=self._get_optuna_sampler(),
+            load_if_exists=True,
+            pruner=self._get_optuna_pruner()
+        )
+
+    def _get_optuna_pruner(self):
+        """Returns Optuna's pruning strategy."""
+        return optuna.pruners.MedianPruner(
+            n_startup_trials=10, # let at least 10 full trials run before pruning
+            n_warmup_steps=300, # dont prune before 300 epochs/steps
+            interval_steps=self.val_interval # check each interval
+        )
+
+    def save_contour_plot_as_png(self, study):
+        """
+        Save the contour plot of hyperparameter interactions as a PNG,
+        automatically extracting parameter names from the study object.
+
+        Args:
+            study: The Optuna study object.
+            output_path: Path to save the PNG file.
+        """
+        # Extract all parameter names from the study trials
+        all_params = set()
+        for trial in study.trials:
+            all_params.update(trial.params.keys())
+        all_params = list(all_params) # Convert to a sorted list for consistency
+
+        # Generate the contour plot
+        fig = optuna.visualization.plot_contour(study, params=all_params)
+
+        # Adjust figure size and font size
+        fig.update_layout(
+            width=6000, height=6000, # Large figure size
+            font=dict(size=40) # Increase font size for better readability
+        )
+
+        # Save the plot as a PNG file
+        fig.write_image(f'explore_results_{self.model_type}/contour_plot.png', scale=1)
+
+        # Extract trial data
+        trials = [
+            {**trial.params, 'objective_value': trial.value}
+            for trial in study.trials if trial.state == optuna.trial.TrialState.COMPLETE
+        ]
+
+        # Convert to DataFrame
+        df = pd.DataFrame(trials)
+
+        # Save to CSV
+        df.to_csv(f"explore_results_{self.model_type}/optuna_results.csv", index=False)