dragon-ml-toolbox 8.1.0__py3-none-any.whl → 9.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (34)
  1. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/METADATA +5 -1
  2. dragon_ml_toolbox-9.0.0.dist-info/RECORD +35 -0
  3. ml_tools/ETL_engineering.py +216 -81
  4. ml_tools/GUI_tools.py +5 -5
  5. ml_tools/MICE_imputation.py +12 -8
  6. ml_tools/ML_callbacks.py +6 -3
  7. ml_tools/ML_datasetmaster.py +37 -20
  8. ml_tools/ML_evaluation.py +4 -4
  9. ml_tools/ML_evaluation_multi.py +26 -17
  10. ml_tools/ML_inference.py +30 -23
  11. ml_tools/ML_models.py +14 -14
  12. ml_tools/ML_optimization.py +4 -3
  13. ml_tools/ML_scaler.py +7 -7
  14. ml_tools/ML_trainer.py +17 -15
  15. ml_tools/PSO_optimization.py +16 -8
  16. ml_tools/RNN_forecast.py +1 -1
  17. ml_tools/SQL.py +22 -13
  18. ml_tools/VIF_factor.py +7 -6
  19. ml_tools/_logger.py +105 -7
  20. ml_tools/custom_logger.py +12 -8
  21. ml_tools/data_exploration.py +20 -15
  22. ml_tools/ensemble_evaluation.py +10 -6
  23. ml_tools/ensemble_inference.py +18 -18
  24. ml_tools/ensemble_learning.py +8 -5
  25. ml_tools/handle_excel.py +15 -11
  26. ml_tools/optimization_tools.py +3 -4
  27. ml_tools/path_manager.py +21 -15
  28. ml_tools/utilities.py +35 -26
  29. dragon_ml_toolbox-8.1.0.dist-info/RECORD +0 -36
  30. ml_tools/_ML_optimization_multi.py +0 -231
  31. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/WHEEL +0 -0
  32. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE +0 -0
  33. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  34. {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/top_level.txt +0 -0
ml_tools/utilities.py CHANGED
@@ -76,11 +76,13 @@ def load_dataframe(
         df = pl.read_csv(path, infer_schema_length=1000)
 
     else:
-        raise ValueError(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+        _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+        raise ValueError()
 
     # This check works for both pandas and polars DataFrames
     if df.shape[0] == 0:
-        raise ValueError(f"DataFrame '{df_name}' loaded from '{path}' is empty.")
+        _LOGGER.error(f"DataFrame '{df_name}' loaded from '{path}' is empty.")
+        raise ValueError()
 
     if verbose:
         _LOGGER.info(f"💾 Loaded {kind.upper()} dataset: '{df_name}' with shape: {df.shape}")
@@ -162,13 +164,14 @@ def merge_dataframes(
         merged_df = pd.concat(dfs, axis=0)
 
     else:
-        raise ValueError(f"Invalid merge direction: {direction}")
+        _LOGGER.error(f"Invalid merge direction: {direction}")
+        raise ValueError()
 
     if reset_index:
         merged_df = merged_df.reset_index(drop=True)
 
     if verbose:
-        _LOGGER.info(f"Merged DataFrame shape: {merged_df.shape}")
+        _LOGGER.info(f"Merged DataFrame shape: {merged_df.shape}")
 
     return merged_df
 
@@ -187,7 +190,7 @@ def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], save_dir: Union[str,Pa
     """
     # This check works for both pandas and polars
     if df.shape[0] == 0:
-        _LOGGER.warning(f"⚠️ Attempting to save an empty DataFrame: '{filename}'. Process Skipped.")
+        _LOGGER.warning(f"Attempting to save an empty DataFrame: '{filename}'. Process Skipped.")
         return
 
     # Create the directory if it doesn't exist
@@ -207,9 +210,10 @@ def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], save_dir: Union[str,Pa
         df.write_csv(output_path) # Polars defaults to utf8 and no index
     else:
         # This error handles cases where an unsupported type is passed
-        raise TypeError(f"Unsupported DataFrame type: {type(df)}. Must be pandas or polars.")
+        _LOGGER.error(f"Unsupported DataFrame type: {type(df)}. Must be pandas or polars.")
+        raise TypeError()
 
-    _LOGGER.info(f"Saved dataset: '{filename}' with shape: {df.shape}")
+    _LOGGER.info(f"Saved dataset: '{filename}' with shape: {df.shape}")
 
 
 def normalize_mixed_list(data: list, threshold: int = 2) -> list[float]:
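
Taken together, the two `save_dataframe` hunks mean empty frames are skipped with a warning, while unsupported types are logged and then raised as a bare `TypeError`. A hedged usage sketch (keyword names follow the call visible later in this diff; the exact output filename handling is assumed):

```python
import numpy as np
import pandas as pd
from ml_tools.utilities import save_dataframe

df = pd.DataFrame({"feature": [1.0, 2.0], "target": [0, 1]})
save_dataframe(df=df, save_dir="./out", filename="demo")  # logs "Saved dataset: 'demo' ..."

save_dataframe(df=pd.DataFrame(), save_dir="./out", filename="empty")
# logs a warning ("Attempting to save an empty DataFrame ...") and returns without writing

save_dataframe(df=np.ones((2, 2)), save_dir="./out", filename="bad")  # type: ignore[arg-type]
# passes the emptiness check, then logs "Unsupported DataFrame type ..." and raises a bare TypeError
```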
@@ -243,7 +247,8 @@ def normalize_mixed_list(data: list, threshold: int = 2) -> list[float]:
 
     # Raise for negative values
     if any(x < 0 for x in float_list):
-        raise ValueError("Negative values are not allowed in the input list.")
+        _LOGGER.error("Negative values are not allowed in the input list.")
+        raise ValueError()
 
     # Step 2: Compute log10 of non-zero values
     nonzero = [x for x in float_list if x > 0]
@@ -302,14 +307,16 @@ def threshold_binary_values(
     elif isinstance(input_array, (list, tuple)):
         array = np.array(input_array)
     else:
-        raise TypeError("Unsupported input type")
+        _LOGGER.error("Unsupported input type")
+        raise TypeError()
 
     array = array.flatten()
     total = array.shape[0]
 
     bin_count = total if binary_values is None else binary_values
     if not (0 <= bin_count <= total):
-        raise ValueError("binary_values must be between 0 and the total number of elements")
+        _LOGGER.error("'binary_values' must be between 0 and the total number of elements")
+        raise ValueError()
 
     if bin_count == 0:
         result = array
@@ -349,9 +356,15 @@ def threshold_binary_values_batch(
     np.ndarray
         Thresholded array, same shape as input.
     """
-    assert input_array.ndim == 2, f"❌ Expected 2D array, got {input_array.ndim}D"
+    if input_array.ndim != 2:
+        _LOGGER.error(f"Expected 2D array, got {input_array.ndim}D array.")
+        raise AssertionError()
+
     batch_size, total_features = input_array.shape
-    assert 0 <= binary_values <= total_features, "❌ binary_values out of valid range"
+
+    if not (0 <= binary_values <= total_features):
+        _LOGGER.error("'binary_values' out of valid range.")
+        raise AssertionError()
 
     if binary_values == 0:
         return input_array.copy()
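
For readers unfamiliar with the function, the contract implied by these checks is: a 2D batch whose trailing `binary_values` columns get snapped to {0, 1} while the shape is preserved. A sketch of the presumed semantics (an illustration, not the library's implementation; the 0.5 cutoff is an assumption):

```python
import numpy as np

def threshold_tail_columns(arr: np.ndarray, binary_values: int) -> np.ndarray:
    # Mirrors the validation shown in the hunk above.
    if arr.ndim != 2:
        raise AssertionError(f"Expected 2D array, got {arr.ndim}D array.")
    _, total_features = arr.shape
    if not (0 <= binary_values <= total_features):
        raise AssertionError("'binary_values' out of valid range.")
    if binary_values == 0:
        return arr.copy()
    out = arr.copy()
    # Assumed rule: snap the trailing binary block to {0, 1} at a 0.5 cutoff.
    out[:, -binary_values:] = (out[:, -binary_values:] >= 0.5).astype(out.dtype)
    return out

batch = np.array([[0.2, 0.7, 0.4],
                  [1.5, 0.1, 0.9]])
print(threshold_tail_columns(batch, binary_values=2))
# [[0.2 1.  0. ]
#  [1.5 0.  1. ]]
```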
@@ -380,15 +393,13 @@ def serialize_object(obj: Any, save_dir: Union[str,Path], filename: str, verbose
         full_path = save_path / sanitized_name
         joblib.dump(obj, full_path)
     except (IOError, OSError, TypeError, TerminatedWorkerError) as e:
-        message = f"Failed to serialize object of type '{type(obj)}': {e}"
+        _LOGGER.error(f"Failed to serialize object of type '{type(obj)}'.")
         if raise_on_error:
-            raise Exception(message)
-        else:
-            _LOGGER.warning(message)
+            raise e
         return None
     else:
         if verbose:
-            _LOGGER.info(f"Object of type '{type(obj)}' saved to '{full_path}'")
+            _LOGGER.info(f"Object of type '{type(obj)}' saved to '{full_path}'")
         return None
 
 
@@ -407,15 +418,13 @@ def deserialize_object(filepath: Union[str,Path], verbose: bool=True, raise_on_e
     try:
         obj = joblib.load(true_filepath)
     except (IOError, OSError, EOFError, TypeError, ValueError) as e:
-        message = f"Failed to deserialize object from '{true_filepath}': {e}"
+        _LOGGER.error(f"Failed to deserialize object from '{true_filepath}'.")
        if raise_on_error:
-            raise Exception(message)
-        else:
-            _LOGGER.warning(message)
+            raise e
         return None
     else:
         if verbose:
-            _LOGGER.info(f"Loaded object of type '{type(obj)}'")
+            _LOGGER.info(f"Loaded object of type '{type(obj)}'.")
         return obj
 
 
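The two serialization hunks change failure behavior the same way: the original exception is re-raised (`raise e`) instead of being wrapped in a generic `Exception`, so callers can catch the concrete type again, and with `raise_on_error=False` the function now logs at error level and returns `None`. A hedged usage sketch (the paths are illustrative):

```python
from ml_tools.utilities import deserialize_object

# Strict mode: failures re-surface with their original type (IOError, EOFError, ...).
try:
    scaler = deserialize_object("artifacts/scaler.joblib", raise_on_error=True)
except (IOError, OSError, EOFError, TypeError, ValueError):
    scaler = None  # the message was already written to the module log

# Lenient mode: the error is logged and None is returned.
model = deserialize_object("artifacts/model.joblib", raise_on_error=False)
if model is None:
    ...  # fall back to retraining or a default artifact
```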
@@ -486,7 +495,8 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
     for dir in list_of_dirs:
         dir_path = make_fullpath(dir)
         if not dir_path.is_dir():
-            raise IOError(f"'{dir}' is not a directory.")
+            _LOGGER.error(f"'{dir}' is not a directory.")
+            raise IOError()
         all_dir_paths.append(dir_path)
 
     # main loop
@@ -502,10 +512,10 @@
             save_dataframe(df=df, save_dir=save_dir, filename=filename)
             total_saved += 1
         except Exception as e:
-            _LOGGER.warning(f"⚠️ Failed to process file '{df_path}'. Reason: {e}")
+            _LOGGER.error(f"Failed to process file '{df_path}'. Reason: {e}")
             continue
 
-    _LOGGER.info(f"{total_saved} single-target datasets were created.")
+    _LOGGER.info(f"{total_saved} single-target datasets were created.")
 
 
 def train_dataset_yielder(
@@ -530,6 +540,5 @@ def train_dataset_yielder(
         yield (df_features, df_target, feature_names, target_col)
 
 
-
 def info():
     _script_info(__all__)
dragon_ml_toolbox-8.1.0.dist-info/RECORD DELETED
@@ -1,36 +0,0 @@
-dragon_ml_toolbox-8.1.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
-dragon_ml_toolbox-8.1.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
-ml_tools/ETL_engineering.py,sha256=4wwZXi9_U7xfCY70jGBaKniOeZ0m75ppxWpQBd_DmLc,39369
-ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
-ml_tools/MICE_imputation.py,sha256=oFHg-OytOzPYTzBR_wIRHhP71cMn3aupDeT59ABsXlQ,11576
-ml_tools/ML_callbacks.py,sha256=noedVMmHZ72Odbg28zqx5wkhhvX2v-jXicKE_NCAiqU,13838
-ml_tools/ML_datasetmaster.py,sha256=tN-GBPEwXRWFBT8r8K0v9b3Bd77DhqSH5FkjDP6BHTw,28847
-ml_tools/ML_evaluation.py,sha256=BER5dOvSTySNzO92gm8tIpqJ5vT-s0iHMmaoly1uUH8,16018
-ml_tools/ML_evaluation_multi.py,sha256=uVtKGYWgOLv34Xj_jz6E_HAYzNb0HwRbMwA8oFZWpUk,12395
-ml_tools/ML_inference.py,sha256=hwtAdyDCE1xtqLgJgyOTAPck0eTmkOCJK1cM_IJSdck,22824
-ml_tools/ML_models.py,sha256=xZiSFh7S6eitl-VjjvNpsikojDvurK8n_ueLEh6_5pM,27979
-ml_tools/ML_optimization.py,sha256=GX-qZ2mCI3gWRCTP5w7lXrZpfGle3J_mE0O68seIoio,13475
-ml_tools/ML_scaler.py,sha256=pGkp1nUpeuoBvbq5hUkieQdxex6kNef1mEbeS_HUCJs,7471
-ml_tools/ML_trainer.py,sha256=6JSmEQaCPSo-S_5plNBTPw-SYgzZpyMNwiqpShJf7qU,23726
-ml_tools/PSO_optimization.py,sha256=9Y074d-B5h4Wvp9YPiy6KAeXM-Yv6Il3gWalKvOLVgo,22705
-ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
-ml_tools/SQL.py,sha256=bkSTmMV4CtEqa67hApYWaRxTqwAlKIc5_b28P1bnDwg,10475
-ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
-ml_tools/_ML_optimization_multi.py,sha256=DrNG3Vf1uUw-3CpYfXREgSGuR4dTpLWY1F3R9j-PYqQ,9816
-ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
-ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
-ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
-ml_tools/custom_logger.py,sha256=nyLRxaRxkqYOFdSjI0X2BWXB8C2IU18QfmqIFKqSedI,5820
-ml_tools/data_exploration.py,sha256=RuMHWagXrSQi1MzAMlYeBeVg7UxhVvEq8gJ9bIam2BM,27103
-ml_tools/ensemble_evaluation.py,sha256=wnqoTPg4WYWf2A8z5XT0eSlW4snEuLCXQVj88sZKzQ4,24683
-ml_tools/ensemble_inference.py,sha256=rtU7eUaQne615n2g7IHZCJI-OvrBCcjxbTkEIvtCGFQ,9414
-ml_tools/ensemble_learning.py,sha256=dAyFgSTyvxJWjc_enJ_8EUoWwiekBeoNyJNxVY-kcUU,21868
-ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
-ml_tools/keys.py,sha256=HtPG8-MWh89C32A7eIlfuuA-DLwkxGkoDfwR2TGN9CQ,1074
-ml_tools/optimization_tools.py,sha256=EL5tgNFwRo-82pbRE1CFVy9noNhULD7wprWuKadPheg,5090
-ml_tools/path_manager.py,sha256=Z8e7w3MPqQaN8xmTnKuXZS6CIW59BFwwqGhGc00sdp4,13692
-ml_tools/utilities.py,sha256=LqXXTovaHbA5AOKRk6Ru6DgAPAM0wPfYU70kUjYBryo,19231
-dragon_ml_toolbox-8.1.0.dist-info/METADATA,sha256=qGTl4__H1ZsbyJHtExcDt14i8ziWXpEy2WaRAELPmTI,6778
-dragon_ml_toolbox-8.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-8.1.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-8.1.0.dist-info/RECORD,,
@@ -1,231 +0,0 @@
1
- import pandas as pd
2
- import torch
3
- import numpy as np
4
- import evotorch
5
- from evotorch.algorithms import NSGA2
6
- from evotorch.logging import PandasLogger
7
- from typing import Literal, Union, Tuple, List, Optional, Any, Callable
8
- from pathlib import Path
9
- from tqdm.auto import trange
10
- from functools import partial
11
- from contextlib import nullcontext
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
-
15
- from .path_manager import make_fullpath, sanitize_filename
16
- from ._logger import _LOGGER
17
- from ._script_info import _script_info
18
- from .ML_inference import PyTorchInferenceHandlerMulti # Using the multi-target handler
19
- from .keys import PyTorchInferenceKeys
20
- from .utilities import threshold_binary_values, save_dataframe
21
- from .SQL import DatabaseManager # Added for SQL saving
22
-
23
- __all__ = [
24
- "create_multi_objective_problem",
25
- "run_multi_objective_optimization",
26
- "plot_pareto_front"
27
- ]
28
-
29
-
30
- def create_multi_objective_problem(
31
- inference_handler: PyTorchInferenceHandlerMulti,
32
- bounds: Tuple[List[float], List[float]],
33
- binary_features: int,
34
- objective_senses: Tuple[Literal["min", "max"], ...],
35
- algorithm: Literal["NSGA2"] = "NSGA2",
36
- population_size: int = 200,
37
- **searcher_kwargs
38
- ) -> Tuple[evotorch.Problem, Callable[[], Any]]:
39
- """
40
- Creates and configures an EvoTorch Problem and a Searcher for multi-objective optimization.
41
-
42
- This function sets up a problem where the goal is to optimize multiple conflicting
43
- objectives simultaneously, using an algorithm like NSGA2 to find the Pareto front.
44
-
45
- Args:
46
- inference_handler (PyTorchInferenceHandlerMulti): An initialized handler for the multi-target model.
47
- bounds (tuple[list[float], list[float]]): Lower and upper bounds for the solution features.
48
- binary_features (int): Number of binary features at the end of the feature vector.
49
- objective_senses (Tuple[Literal["min", "max"], ...]): A tuple specifying the optimization
50
- goal for each target (e.g., ("max", "min", "max")). The length of this tuple
51
- must match the number of outputs from the model.
52
- algorithm (str): The multi-objective search algorithm to use. Currently supports "NSGA2".
53
- population_size (int): The number of solutions in each generation.
54
- **searcher_kwargs: Additional keyword arguments for the search algorithm's constructor.
55
-
56
- Returns:
57
- A tuple containing the configured multi-objective Problem and the Searcher factory.
58
- """
59
- lower_bounds, upper_bounds = list(bounds[0]), list(bounds[1])
60
-
61
- if binary_features > 0:
62
- lower_bounds.extend([0.45] * binary_features)
63
- upper_bounds.extend([0.55] * binary_features)
64
-
65
- solution_length = len(lower_bounds)
66
- device = inference_handler.device
67
-
68
- def fitness_func(solution_tensor: torch.Tensor) -> torch.Tensor:
69
- """
70
- The fitness function for a multi-objective problem.
71
- It returns the entire output tensor from the model. EvoTorch handles the rest.
72
- """
73
- # The handler returns a tensor of shape [batch_size, num_targets]
74
- predictions = inference_handler.predict_batch(solution_tensor)[PyTorchInferenceKeys.PREDICTIONS]
75
- return predictions
76
-
77
- if algorithm == "NSGA2":
78
- problem = evotorch.Problem(
79
- objective_sense=objective_senses,
80
- objective_func=fitness_func,
81
- solution_length=solution_length,
82
- bounds=(lower_bounds, upper_bounds),
83
- device=device,
84
- vectorized=True,
85
- num_actors='max' # Use available CPU cores
86
- )
87
- SearcherClass = NSGA2
88
- if 'popsize' not in searcher_kwargs:
89
- searcher_kwargs['popsize'] = population_size
90
- else:
91
- raise ValueError(f"Unknown multi-objective algorithm '{algorithm}'.")
92
-
93
- searcher_factory = partial(SearcherClass, problem, **searcher_kwargs)
94
- return problem, searcher_factory
95
-
96
-
97
- def run_multi_objective_optimization(
98
- problem: evotorch.Problem,
99
- searcher_factory: Callable[[], Any],
100
- num_generations: int,
101
- run_name: str,
102
- binary_features: int,
103
- save_dir: Union[str, Path],
104
- feature_names: List[str],
105
- target_names: List[str],
106
- save_format: Literal['csv', 'sqlite', 'both'] = 'csv',
107
- verbose: bool = True
108
- ):
109
- """
110
- Runs the multi-objective evolutionary optimization process to find the Pareto front.
111
-
112
- This function executes a multi-objective algorithm (like NSGA2) and saves the
113
- entire set of non-dominated solutions (the Pareto front) to the specified format(s).
114
- It also generates and saves a plot of the Pareto front.
115
-
116
- Args:
117
- problem (evotorch.Problem): The configured multi-objective problem.
118
- searcher_factory (Callable): A factory function to generate a fresh searcher instance.
119
- num_generations (int): The number of generations to run the algorithm.
120
- run_name (str): A name for this optimization run, used for filenames/table names.
121
- binary_features (int): Number of binary features in the solution vector.
122
- save_dir (str | Path): The directory where the result files will be saved.
123
- feature_names (List[str]): Names of the solution features for labeling columns.
124
- target_names (List[str]): Names of the target objectives for labeling columns.
125
- save_format (str): The format to save results in ('csv', 'sqlite', or 'both').
126
- verbose (bool): If True, attaches a logger and saves the evolution history.
127
- """
128
- save_path = make_fullpath(save_dir, make=True, enforce="directory")
129
- sanitized_run_name = sanitize_filename(run_name)
130
-
131
- if len(target_names) != problem.num_objectives:
132
- raise ValueError("The number of `target_names` must match the number of objectives in the problem.")
133
-
134
- searcher = searcher_factory()
135
- _LOGGER.info(f"🤖 Starting multi-objective optimization with {searcher.__class__.__name__} for {num_generations} generations...")
136
-
137
- logger = PandasLogger(searcher) if verbose else None
138
- searcher.run(num_generations)
139
-
140
- pareto_front = searcher.status["pareto_front"]
141
- _LOGGER.info(f"✅ Optimization complete. Found {len(pareto_front)} non-dominated solutions.")
142
-
143
- solutions_np = pareto_front.values.cpu().numpy()
144
- objectives_np = pareto_front.evals.cpu().numpy()
145
-
146
- if binary_features > 0:
147
- solutions_np = threshold_binary_values(input_array=solutions_np, binary_values=binary_features)
148
-
149
- results_df = pd.DataFrame(solutions_np, columns=feature_names)
150
- objective_cols = []
151
- for i, name in enumerate(target_names):
152
- col_name = f"predicted_{name}"
153
- results_df[col_name] = objectives_np[:, i]
154
- objective_cols.append(col_name)
155
-
156
- # --- Saving Logic ---
157
- if save_format in ['csv', 'both']:
158
- csv_path = save_path / f"pareto_front_{sanitized_run_name}.csv"
159
- results_df.to_csv(csv_path, index=False)
160
- _LOGGER.info(f"📄 Pareto front data saved to '{csv_path.name}'")
161
-
162
- if save_format in ['sqlite', 'both']:
163
- db_path = save_path / "Optimization_Multi.db"
164
- with DatabaseManager(db_path) as db:
165
- db.insert_from_dataframe(
166
- table_name=sanitized_run_name,
167
- df=results_df,
168
- if_exists='replace'
169
- )
170
- _LOGGER.info(f"🗃️ Pareto front data saved to table '{sanitized_run_name}' in '{db_path.name}'")
171
-
172
- # --- Plotting Logic ---
173
- plot_pareto_front(
174
- results_df,
175
- objective_cols=objective_cols,
176
- save_path=save_path / f"pareto_plot_{sanitized_run_name}.svg"
177
- )
178
-
179
- if logger:
180
- log_df = logger.to_dataframe()
181
- save_dataframe(df=log_df, save_dir=save_path / "EvolutionLogs", filename=f"log_{sanitized_run_name}")
182
-
183
-
184
- def plot_pareto_front(results_df: pd.DataFrame, objective_cols: List[str], save_path: Path):
185
- """
186
- Generates and saves a plot of the Pareto front.
187
-
188
- - For 2 objectives, it creates a 2D scatter plot.
189
- - For 3 objectives, it creates a 3D scatter plot.
190
- - For >3 objectives, it creates a scatter plot matrix (pairs plot).
191
-
192
- Args:
193
- results_df (pd.DataFrame): DataFrame containing the optimization results.
194
- objective_cols (List[str]): The names of the columns that hold the objective values.
195
- save_path (Path): The full path (including filename) to save the SVG plot.
196
- """
197
- num_objectives = len(objective_cols)
198
- _LOGGER.info(f"🎨 Generating Pareto front plot for {num_objectives} objectives...")
199
-
200
- plt.style.use('seaborn-v0_8-whitegrid')
201
-
202
- if num_objectives == 2:
203
- fig, ax = plt.subplots(figsize=(8, 6), dpi=120)
204
- ax.scatter(results_df[objective_cols[0]], results_df[objective_cols[1]], alpha=0.7, edgecolors='k')
205
- ax.set_xlabel(objective_cols[0])
206
- ax.set_ylabel(objective_cols[1])
207
- ax.set_title("Pareto Front (2D)")
208
-
209
- elif num_objectives == 3:
210
- fig = plt.figure(figsize=(9, 7), dpi=120)
211
- ax = fig.add_subplot(111, projection='3d')
212
- ax.scatter(results_df[objective_cols[0]], results_df[objective_cols[1]], results_df[objective_cols[2]], alpha=0.7, depthshade=True)
213
- ax.set_xlabel(objective_cols[0])
214
- ax.set_ylabel(objective_cols[1])
215
- ax.set_zlabel(objective_cols[2])
216
- ax.set_title("Pareto Front (3D)")
217
-
218
- else: # > 3 objectives
219
- _LOGGER.info(" -> More than 3 objectives found, generating a scatter plot matrix.")
220
- g = sns.pairplot(results_df[objective_cols], diag_kind="kde", plot_kws={'alpha': 0.6})
221
- g.fig.suptitle("Pareto Front (Pairs Plot)", y=1.02)
222
- plt.savefig(save_path, bbox_inches='tight')
223
- plt.close()
224
- _LOGGER.info(f"📊 Pareto plot saved to '{save_path.name}'")
225
- return
226
-
227
- plt.tight_layout()
228
- plt.savefig(save_path)
229
- plt.close()
230
- _LOGGER.info(f"📊 Pareto plot saved to '{save_path.name}'")
231
-
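
With `ml_tools/_ML_optimization_multi.py` removed in 9.0.0, the workflow it offered is gone from the package. For reference, a typical invocation of the deleted module, reconstructed from its own signatures and docstrings (the bounds, names, and `handler` below are placeholder values):

```python
# Reconstructed usage of the removed module (dragon-ml-toolbox <= 8.1.0); illustrative only.
from ml_tools._ML_optimization_multi import (
    create_multi_objective_problem,
    run_multi_objective_optimization,
)

# 'handler' stands in for an initialized PyTorchInferenceHandlerMulti.
problem, searcher_factory = create_multi_objective_problem(
    inference_handler=handler,
    bounds=([0.0, 0.0], [1.0, 1.0]),   # lower/upper bounds for the continuous features
    binary_features=1,                 # trailing binary feature, relaxed to [0.45, 0.55] internally
    objective_senses=("max", "min"),   # one sense per model output
)

run_multi_objective_optimization(
    problem=problem,
    searcher_factory=searcher_factory,
    num_generations=100,
    run_name="demo_run",
    binary_features=1,
    save_dir="./optimization_results",
    feature_names=["x1", "x2", "flag"],
    target_names=["yield", "cost"],
    save_format="both",                # CSV plus a SQLite table
)
```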