octopi 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
@@ -0,0 +1,244 @@
1
+ from monai.metrics import ConfusionMatrixMetric
2
+ from mlflow.tracking import MlflowClient
3
+ from octopi.pytorch import trainer
4
+ from octopi.models import common
5
+ import torch, mlflow, optuna, gc
6
+ from octopi.utils import io
7
+
8
class BayesianModelSearch:
    """Optuna objective wrapper: build, train and score one model per trial.

    Each trial asks ``common.get_model(model_type)`` to sample an architecture,
    samples a loss function (``common.get_loss_function``) and a crop size, trains
    the model with ``trainer.ModelTrainer`` and returns the trainer's
    ``best_metric`` result as the objective value. Whenever a trial beats the
    study's best value so far, its weights and configuration are written to
    ``results_dir``.
    """

    def __init__(self, data_generator, model_type="Unet", parent_run_id=None, parent_run_name=None):
        """
        Store the data generator and prepare empty per-trial state.

        Args:
            data_generator (object): Data generator; must expose ``Nclasses``. It is
                passed to the trainer and to ``io.save_parameters_to_yaml``.
            model_type (str): Model name passed to ``common.get_model``
                (e.g. "Unet", "AttentionUnet"); also names the results directory.
            parent_run_id (str, optional): MLflow ID of the parent run, logged as a
                param on every trial run by ``objective``.
            parent_run_name (str, optional): MLflow name of the parent run, logged
                as a param on every trial run by ``objective``.
        """
        self.data_generator = data_generator
        self.Nclasses = data_generator.Nclasses
        self.device = None
        self.model_type = model_type

        # Per-trial state: populated by my_build_model / _define_optimizer and
        # reset by cleanup.
        self.model = None
        self.model_builder = None
        self.config = None
        self.loss_function = None
        self.metrics_function = None
        self.optimizer = None
        self.sampling = None

        # Parallel (multi-GPU) MLflow state, set by multi_gpu_objective.
        self.trial_run = None
        self.client = None
        self.target_run_id = None

        # Parent MLflow run identifiers, logged on every trial run.
        self.parent_run_id = parent_run_id
        self.parent_run_name = parent_run_name

        # Directory where the best model weights and config are written
        self.results_dir = f'explore_results_{self.model_type}'

    def my_build_model(self, trial):
        """Build the model, loss, metrics and sampling settings for ``trial``.

        Architecture hyper-parameters are suggested by the model builder's
        ``bayesian_search``, the loss function by ``common.get_loss_function``.
        The crop size is suggested here (48-192, step 16) and copied into the
        builder's config as ``dim_in``.
        """
        self.model_builder = common.get_model(self.model_type)
        self.model_builder.bayesian_search(trial, self.Nclasses)
        self.model = self.model_builder.model.to(self.device)
        self.config = self.model_builder.config

        self.loss_function = common.get_loss_function(trial)

        # Recall / precision / F1 with reduction="none", background class excluded
        self.metrics_function = ConfusionMatrixMetric(
            include_background=False,
            metric_name=["recall", "precision", "f1 score"],
            reduction="none",
        )

        self.sampling = {
            'crop_size': trial.suggest_int("crop_size", 48, 192, step=16),
            'num_samples': 8,
        }
        self.config['dim_in'] = self.sampling['crop_size']

    def _define_optimizer(self, trial):
        """Create an AdamW optimizer for the current model.

        Learning rate (1e-3) and weight decay (1e-4) are currently fixed rather than
        searched; ``trial`` is accepted so they can be made tunable later
        (e.g. via ``trial.suggest_float(..., log=True)``).
        """
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)

    def _train_model(self, trial, model_trainer, epochs, val_interval, crop_size, num_samples, best_metric):
        """Train the model for one trial and return the trainer's ``best_metric`` result.

        Raises:
            optuna.TrialPruned: re-raised unchanged if the trainer raises it. It is
                also raised (chained to the original error) on CUDA OOM, which sets the
                ``out_of_memory`` user attr. Any other training error has its text
                stored in the ``error`` user attr before the prune is raised.
        """
        try:
            results = model_trainer.train(
                self.data_generator,
                model_save_path=None,
                crop_size=crop_size,
                max_epochs=epochs,
                val_interval=val_interval,
                my_num_samples=num_samples,
                best_metric=best_metric,
                use_mlflow=True,
                verbose=False,
                trial=trial,
            )
            return results['best_metric']

        except optuna.TrialPruned:
            # A genuine prune: propagate it instead of reporting it as an error below
            raise

        except torch.cuda.OutOfMemoryError as e:
            print(f"[Trial Failed] OOM Error for model={self.model_type}, crop_size={crop_size}, num_samples={num_samples}")
            trial.set_user_attr("out_of_memory", True)
            raise optuna.TrialPruned() from e

        except Exception as e:
            print(f"[Trial Failed] Unexpected error: {e}")
            trial.set_user_attr("error", str(e))
            raise optuna.TrialPruned() from e

    def objective(self, trial, epochs, device, val_interval=15, best_metric="avg_f1"):
        """Run the full training process for ``trial`` inside a nested MLflow run.

        Returns:
            The trial's best validation metric (the value Optuna maximizes).
        """
        self.device = device
        trial_num = f"trial_{trial.number}"

        with mlflow.start_run(run_name=trial_num, nested=True):
            model_trainer = None
            try:
                self.my_build_model(trial)
                self._define_optimizer(trial)
                model_trainer = trainer.ModelTrainer(
                    self.model, self.device, self.loss_function,
                    self.metrics_function, self.optimizer,
                )

                score = self._train_model(
                    trial, model_trainer, epochs, val_interval,
                    self.sampling['crop_size'], self.sampling['num_samples'],
                    best_metric,
                )

                # Explicitly record which parent run this trial belongs to
                mlflow.log_param("parent_run_id", self.parent_run_id)
                mlflow.log_param("parent_run_name", self.parent_run_name)

                self._save_best_model(trial, model_trainer, score)
            finally:
                # Runs on success, prune and OOM alike: logs model/optimizer params
                # and releases GPU memory before the next trial starts.
                self.cleanup(model_trainer)

        return score

    def _setup_parallel_trial_run(self, trial, parent_run=None, gpu_count=1):
        """Create a child MLflow run under ``parent_run`` and pick a device for the trial.

        With more than one GPU, trials are assigned round-robin by trial number and
        the chosen GPU is made current for this thread.

        Returns:
            tuple: (device, MlflowClient, child run ID).
        """
        trial_num = f"trial_{trial.number}"

        # Create a child run under the parent MLflow experiment
        mlflow_client = MlflowClient()
        self.trial_run = mlflow_client.create_run(
            experiment_id=mlflow_client.get_run(parent_run.info.run_id).info.experiment_id,
            tags={"mlflow.parentRunId": parent_run.info.run_id},
            run_name=trial_num,
        )
        target_run_id = self.trial_run.info.run_id
        print(f"Logging trial {trial.number} data to MLflow run: {target_run_id}")

        if gpu_count > 1:
            gpu_id = trial.number % gpu_count
            device = torch.device(f"cuda:{gpu_id}")
            torch.cuda.set_device(device)
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        return device, mlflow_client, target_run_id

    def multi_gpu_objective(self, parent_run, trial, epochs, val_interval=5, best_metric="avg_f1", gpu_count=1):
        """
        Train one trial on its assigned GPU, logging to a child of ``parent_run``.

        The child MLflow run is always terminated with status FINISHED, KILLED
        (pruned) or FAILED. Non-prune errors are recorded on the trial and
        converted to ``optuna.TrialPruned``.
        """
        self.device, self.client, self.target_run_id = self._setup_parallel_trial_run(trial, parent_run, gpu_count)

        model_trainer = None
        run_status = "FAILED"  # default; overwritten on success/prune
        try:
            # Build inside the try so a build failure still terminates the MLflow run
            self.my_build_model(trial)
            self._define_optimizer(trial)
            model_trainer = trainer.ModelTrainer(
                self.model, self.device, self.loss_function,
                self.metrics_function, self.optimizer,
            )
            model_trainer.set_parallel_mlflow(self.client, self.target_run_id)

            score = self._train_model(
                trial, model_trainer, epochs, val_interval,
                self.sampling['crop_size'], self.sampling['num_samples'],
                best_metric,
            )
            run_status = "FINISHED"
            self._save_best_model(trial, model_trainer, score)
            return score
        except optuna.TrialPruned:
            run_status = "KILLED"  # communicates early stop
            raise
        except Exception as e:
            run_status = "FAILED"
            print(f"[Trial Failed] Unexpected error: {e}")
            trial.set_user_attr("error", str(e))
            raise optuna.TrialPruned() from e
        finally:
            self.cleanup(model_trainer)
            # Each step is guarded separately so a logging failure cannot prevent
            # the run from being terminated.
            try:
                if model_trainer is not None:
                    model_trainer.my_log_params({"parent_run_name": parent_run.info.run_name})
            except Exception as e:
                print(f"[Cleanup Failed] Unexpected error: {e}")
            try:
                if self.client is not None and self.target_run_id is not None:
                    self.client.set_terminated(self.target_run_id, status=run_status)
            except Exception as e:
                print(f"[Cleanup Failed] Unexpected error: {e}")

    def _save_best_model(self, trial, model_trainer, score):
        """Save weights and config to ``results_dir`` if ``score`` beats the study's best value."""
        if score is None:
            return
        if score > self.get_best_score(trial):
            import os
            os.makedirs(self.results_dir, exist_ok=True)
            torch.save(model_trainer.model_weights, f'{self.results_dir}/best_model.pth')
            io.save_parameters_to_yaml(self.model_builder, model_trainer, self.data_generator,
                                       f'{self.results_dir}/model_config.yaml')

    def get_best_score(self, trial):
        """Return the study's best value so far, or ``-inf`` if none is available yet."""
        try:
            return trial.study.best_value
        except ValueError:
            return -float('inf')

    def cleanup(self, model_trainer):
        """Log the trial's model/optimizer parameters and release its GPU memory.

        Safe when the trial never got that far (``model_trainer`` or the model
        builder is None): logging is then skipped. Logging errors are printed,
        not raised, because this runs from ``finally`` blocks where raising
        would mask the original exception.
        """
        model_builder = getattr(self, "model_builder", None)
        if model_trainer is not None and model_builder is not None:
            try:
                params = {
                    'model': model_builder.get_model_parameters(),
                    'optimizer': io.get_optimizer_parameters(model_trainer),
                }
                model_trainer.my_log_params(io.flatten_params(params))
            except Exception as e:
                print(f"[Cleanup Failed] Could not log parameters: {e}")

        # Drop this object's references so the tensors become collectable
        # (the caller still holds its own model_trainer reference).
        self.optimizer = None
        self.model = None
        self.model_builder = None

        # Collect first so freed tensors go back to PyTorch's caching allocator,
        # then release the cached blocks.
        gc.collect()
        torch.cuda.empty_cache()
+
@@ -0,0 +1,291 @@
1
+ from octopi.datasets import generators, multi_config_generator
2
+ from octopi.utils import config, parsers
3
+ from octopi.pytorch import hyper_search
4
+ import torch, mlflow, optuna
5
+ from typing import List
6
+ import pandas as pd
7
+
8
class ModelSearchSubmit:
    """Run an Optuna + MLflow architecture search for one model type.

    The training/validation data generator is built on construction. Call
    ``run_model_search`` to launch the study: sequential trials on a single
    device, or one trial per GPU in parallel when several GPUs are visible.
    The study DB, best model, contour plot and a CSV of completed trials are
    written to ``explore_results_<model_type>/``.
    """

    def __init__(
        self,
        copick_config: "str | dict",
        target_name: str,
        target_user_id: str,
        target_session_id: str,
        tomo_algorithm: str,
        voxel_size: float,
        model_type: str,
        best_metric: str = 'avg_f1',
        num_epochs: int = 1000,
        num_trials: int = 100,
        data_split: "str | float" = 0.8,
        random_seed: int = 42,
        val_interval: int = 10,
        tomo_batch_size: int = 15,
        trainRunIDs: List[str] = None,
        validateRunIDs: List[str] = None,
        mlflow_experiment_name: str = 'explore',
    ):
        """
        Initialize the ModelSearch class for architecture search with Optuna.

        Args:
            copick_config (str or dict): Path to the CoPick configuration file, or a
                dictionary (session -> config) for multi-config training.
            target_name (str): Name of the target for segmentation.
            target_user_id (str): Optional user ID for tracking.
            target_session_id (str): Optional session ID for tracking.
            tomo_algorithm (str): Tomogram algorithm to use.
            voxel_size (float): Voxel size for tomograms.
            model_type (str): Type of model to use.
            best_metric (str): Metric to optimize.
            num_epochs (int): Number of epochs per trial.
            num_trials (int): Number of trials for hyperparameter optimization.
            data_split (str or float): Data split specification, passed to
                ``parsers.parse_data_split``.
            random_seed (int): Seed for reproducibility.
            val_interval (int): Validation interval (also the pruner's check interval).
            tomo_batch_size (int): Batch size for tomogram loading.
            trainRunIDs (List[str]): List of training run IDs.
            validateRunIDs (List[str]): List of validation run IDs.
            mlflow_experiment_name (str): MLflow experiment name.
        """
        # Input parameters
        self.copick_config = copick_config
        self.target_name = target_name
        self.target_user_id = target_user_id
        self.target_session_id = target_session_id
        self.tomo_algorithm = tomo_algorithm
        self.voxel_size = voxel_size
        self.model_type = model_type
        self.mlflow_experiment_name = mlflow_experiment_name
        self.random_seed = random_seed
        self.num_epochs = num_epochs
        self.num_trials = num_trials
        self.tomo_batch_size = tomo_batch_size
        self.best_metric = best_metric
        self.val_interval = val_interval
        self.trainRunIDs = trainRunIDs
        self.validateRunIDs = validateRunIDs
        self.data_split = data_split

        # Data generator - initialized in _initialize_data_generator()
        self.data_generator = None

        config.set_seed(self.random_seed)
        self._initialize_data_generator()

    def _results_dir(self):
        """Return the directory holding this search's study DB and result files."""
        return f'explore_results_{self.model_type}'

    def _initialize_data_generator(self):
        """Build the data generator and split runs into training and validation sets."""
        self._print_input_configs()

        # A dict of configs selects the multi-config loader; a single path the plain one
        if isinstance(self.copick_config, dict):
            loader_cls = multi_config_generator.MultiConfigTrainLoaderManager
        else:
            loader_cls = generators.TrainLoaderManager

        self.data_generator = loader_cls(
            self.copick_config,
            self.target_name,
            target_session_id=self.target_session_id,
            target_user_id=self.target_user_id,
            tomo_algorithm=self.tomo_algorithm,
            voxel_size=self.voxel_size,
            tomo_batch_size=self.tomo_batch_size,
        )

        ratios = parsers.parse_data_split(self.data_split)
        self.data_generator.get_data_splits(
            trainRunIDs=self.trainRunIDs,
            validateRunIDs=self.validateRunIDs,
            train_ratio=ratios[0], val_ratio=ratios[1], test_ratio=ratios[2],
            create_test_dataset=False,
        )

        self.data_generator.get_reload_frequency(self.num_epochs)
        self.Nclass = self.data_generator.Nclasses

    def _print_input_configs(self):
        """Print the configuration(s) being trained on, for debugging purposes."""
        print(f'\nTraining with:')
        if isinstance(self.copick_config, dict):
            # Loop names chosen to avoid shadowing the imported `config` module
            for session_name, session_config in self.copick_config.items():
                print(f' {session_name}: {session_config}')
        else:
            print(f' {self.copick_config}')
        print()

    def run_model_search(self):
        """Perform model architecture search using Optuna and MLflow."""
        try:
            tracking_uri = config.mlflow_setup()
            mlflow.set_tracking_uri(tracking_uri)
        except Exception as e:
            # Best effort: MLflow keeps its default tracking URI
            print(f'Failed to set up MLflow tracking: {e}')

        mlflow.set_experiment(self.mlflow_experiment_name)

        # SQLite cannot create missing parent directories, so create it up front
        import os
        results_dir = self._results_dir()
        os.makedirs(results_dir, exist_ok=True)

        # Storage with heartbeats so stale (crashed) trials are detected and retried
        storage_url = f"sqlite:///{results_dir}/trials.db"
        self.storage = optuna.storages.RDBStorage(
            url=storage_url,
            heartbeat_interval=60,  # Record heartbeat every minute
            grace_period=600,       # 10 minutes grace period
            failed_trial_callback=optuna.storages.RetryFailedTrialCallback(max_retry=1),
        )

        gpu_count = torch.cuda.device_count()
        print(f'Running Architecture Search Over {gpu_count} GPUs\n')

        if gpu_count > 1:
            self._multi_gpu_optuna(gpu_count)
        else:
            model_search = hyper_search.BayesianModelSearch(self.data_generator, self.model_type)
            self._single_gpu_optuna(model_search)

    def _single_gpu_optuna(self, model_search):
        """Run Optuna optimization sequentially on a single device (GPU or CPU)."""
        with mlflow.start_run(nested=False) as parent_run:

            model_search.parent_run_id = parent_run.info.run_id
            model_search.parent_run_name = parent_run.info.run_name

            mlflow.log_params({"random_seed": self.random_seed})
            mlflow.log_params(self.data_generator.get_dataloader_parameters())
            mlflow.log_params({"parent_run_name": parent_run.info.run_name})

            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            study = self.get_optuna_study()
            study.optimize(
                lambda trial: model_search.objective(
                    trial, self.num_epochs, device,
                    val_interval=self.val_interval,
                    best_metric=self.best_metric,
                ),
                n_trials=self.num_trials,
            )

            self.save_contour_plot_as_png(study)
            self._report_best_trial(study)

    def _multi_gpu_optuna(self, gpu_count):
        """Run Optuna optimization with one concurrent trial per GPU."""
        with mlflow.start_run() as parent_run:

            mlflow.log_params({"random_seed": self.random_seed})
            mlflow.log_params(self.data_generator.get_dataloader_parameters())
            mlflow.log_params({"parent_run_name": parent_run.info.run_name})

            # A fresh BayesianModelSearch per trial keeps per-trial state separate
            # between the concurrently running jobs.
            study = self.get_optuna_study()
            study.optimize(
                lambda trial: hyper_search.BayesianModelSearch(self.data_generator, self.model_type).multi_gpu_objective(
                    parent_run, trial,
                    self.num_epochs,
                    best_metric=self.best_metric,
                    val_interval=self.val_interval,
                    gpu_count=gpu_count,
                ),
                n_trials=self.num_trials,
                n_jobs=gpu_count,
            )

            self.save_contour_plot_as_png(study)
            self._report_best_trial(study)

    def _report_best_trial(self, study):
        """Print the best trial's value and parameters, if any trial completed."""
        try:
            best_trial = study.best_trial
        except ValueError:
            # Optuna raises ValueError when no trial completed (e.g. all were pruned)
            print("No completed trials; no best trial to report.")
            return
        print(f"Best trial: {best_trial.value}")
        print(f"Best params: {study.best_params}")

    def _get_optuna_sampler(self):
        """Return Optuna's multivariate TPE sampler (10 random startup trials).

        ``optuna.samplers.BoTorchSampler`` is a possible drop-in alternative.
        """
        return optuna.samplers.TPESampler(
            n_startup_trials=10,
            n_ei_candidates=24,
            multivariate=True,
        )

    def get_optuna_study(self, study_name=None):
        """Return the Optuna study object (maximizing) backed by ``self.storage``.

        Args:
            study_name (str, optional): Name of the study. With the default
                ``None``, Optuna generates a fresh unique name, so
                ``load_if_exists`` never resumes an earlier study. Pass a
                stable name to resume from the stored trials.
        """
        return optuna.create_study(
            storage=self.storage,
            study_name=study_name,
            direction="maximize",
            sampler=self._get_optuna_sampler(),
            load_if_exists=True,
            pruner=self._get_optuna_pruner(),
        )

    def _get_optuna_pruner(self):
        """Return Optuna's median pruning strategy."""
        return optuna.pruners.MedianPruner(
            n_startup_trials=10,                # let at least 10 full trials run before pruning
            n_warmup_steps=300,                 # don't prune before 300 epochs/steps
            interval_steps=self.val_interval,   # check each validation interval
        )

    def save_contour_plot_as_png(self, study):
        """
        Save completed trials as a CSV and a hyperparameter contour plot as a PNG.

        Both files go to ``explore_results_<model_type>/``. Parameter names for the
        plot are collected from every trial in ``study``. They are sorted, so the
        plot's axis order is the same from run to run.

        Args:
            study: The Optuna study object.
        """
        results_dir = self._results_dir()

        # Write the CSV first. It only needs pandas, so a failure in the plotting
        # step below cannot lose the results table.
        trials = [
            {**trial.params, 'objective_value': trial.value}
            for trial in study.trials if trial.state == optuna.trial.TrialState.COMPLETE
        ]
        pd.DataFrame(trials).to_csv(f"{results_dir}/optuna_results.csv", index=False)

        all_params = sorted({name for trial in study.trials for name in trial.params})

        try:
            fig = optuna.visualization.plot_contour(study, params=all_params)
            fig.update_layout(
                width=6000, height=6000,  # Large figure size
                font=dict(size=40),       # Increase font size for better readability
            )
            # PNG export may fail, e.g. if Plotly's static-image engine is not installed
            fig.write_image(f'{results_dir}/contour_plot.png', scale=1)
        except Exception as e:
            print(f'Failed to save contour plot: {e}')