dragon-ml-toolbox 8.1.0__py3-none-any.whl → 9.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/METADATA +5 -1
- dragon_ml_toolbox-9.0.0.dist-info/RECORD +35 -0
- ml_tools/ETL_engineering.py +216 -81
- ml_tools/GUI_tools.py +5 -5
- ml_tools/MICE_imputation.py +12 -8
- ml_tools/ML_callbacks.py +6 -3
- ml_tools/ML_datasetmaster.py +37 -20
- ml_tools/ML_evaluation.py +4 -4
- ml_tools/ML_evaluation_multi.py +26 -17
- ml_tools/ML_inference.py +30 -23
- ml_tools/ML_models.py +14 -14
- ml_tools/ML_optimization.py +4 -3
- ml_tools/ML_scaler.py +7 -7
- ml_tools/ML_trainer.py +17 -15
- ml_tools/PSO_optimization.py +16 -8
- ml_tools/RNN_forecast.py +1 -1
- ml_tools/SQL.py +22 -13
- ml_tools/VIF_factor.py +7 -6
- ml_tools/_logger.py +105 -7
- ml_tools/custom_logger.py +12 -8
- ml_tools/data_exploration.py +20 -15
- ml_tools/ensemble_evaluation.py +10 -6
- ml_tools/ensemble_inference.py +18 -18
- ml_tools/ensemble_learning.py +8 -5
- ml_tools/handle_excel.py +15 -11
- ml_tools/optimization_tools.py +3 -4
- ml_tools/path_manager.py +21 -15
- ml_tools/utilities.py +35 -26
- dragon_ml_toolbox-8.1.0.dist-info/RECORD +0 -36
- ml_tools/_ML_optimization_multi.py +0 -231
- {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/top_level.txt +0 -0
ml_tools/utilities.py
CHANGED
@@ -76,11 +76,13 @@ def load_dataframe(
         df = pl.read_csv(path, infer_schema_length=1000)
 
     else:
-
+        _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
+        raise ValueError()
 
     # This check works for both pandas and polars DataFrames
     if df.shape[0] == 0:
-
+        _LOGGER.error(f"DataFrame '{df_name}' loaded from '{path}' is empty.")
+        raise ValueError()
 
     if verbose:
         _LOGGER.info(f"💾 Loaded {kind.upper()} dataset: '{df_name}' with shape: {df.shape}")
@@ -162,13 +164,14 @@ def merge_dataframes(
         merged_df = pd.concat(dfs, axis=0)
 
     else:
-
+        _LOGGER.error(f"Invalid merge direction: {direction}")
+        raise ValueError()
 
     if reset_index:
         merged_df = merged_df.reset_index(drop=True)
 
     if verbose:
-        _LOGGER.info(f"
+        _LOGGER.info(f"Merged DataFrame shape: {merged_df.shape}")
 
     return merged_df
 
@@ -187,7 +190,7 @@ def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], save_dir: Union[str,Pa
     """
     # This check works for both pandas and polars
     if df.shape[0] == 0:
-        _LOGGER.warning(f"
+        _LOGGER.warning(f"Attempting to save an empty DataFrame: '{filename}'. Process Skipped.")
         return
 
     # Create the directory if it doesn't exist
@@ -207,9 +210,10 @@ def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], save_dir: Union[str,Pa
         df.write_csv(output_path) # Polars defaults to utf8 and no index
     else:
         # This error handles cases where an unsupported type is passed
-
+        _LOGGER.error(f"Unsupported DataFrame type: {type(df)}. Must be pandas or polars.")
+        raise TypeError()
 
-    _LOGGER.info(f"
+    _LOGGER.info(f"Saved dataset: '{filename}' with shape: {df.shape}")
 
 
 def normalize_mixed_list(data: list, threshold: int = 2) -> list[float]:
@@ -243,7 +247,8 @@ def normalize_mixed_list(data: list, threshold: int = 2) -> list[float]:
 
     # Raise for negative values
     if any(x < 0 for x in float_list):
-
+        _LOGGER.error("Negative values are not allowed in the input list.")
+        raise ValueError()
 
     # Step 2: Compute log10 of non-zero values
     nonzero = [x for x in float_list if x > 0]
@@ -302,14 +307,16 @@ def threshold_binary_values(
     elif isinstance(input_array, (list, tuple)):
         array = np.array(input_array)
     else:
-
+        _LOGGER.error("Unsupported input type")
+        raise TypeError()
 
     array = array.flatten()
     total = array.shape[0]
 
     bin_count = total if binary_values is None else binary_values
     if not (0 <= bin_count <= total):
-
+        _LOGGER.error("'binary_values' must be between 0 and the total number of elements")
+        raise ValueError()
 
     if bin_count == 0:
         result = array
@@ -349,9 +356,15 @@ def threshold_binary_values_batch(
     np.ndarray
         Thresholded array, same shape as input.
     """
-
+    if input_array.ndim != 2:
+        _LOGGER.error(f"Expected 2D array, got {input_array.ndim}D array.")
+        raise AssertionError()
+
     batch_size, total_features = input_array.shape
-
+
+    if not (0 <= binary_values <= total_features):
+        _LOGGER.error("'binary_values' out of valid range.")
+        raise AssertionError()
 
     if binary_values == 0:
         return input_array.copy()
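The guard clauses added to threshold_binary_values_batch validate the input shape before unpacking it. A minimal usage sketch of the new behavior (function and parameter names are taken from the hunk above; the arrays are placeholders):

import numpy as np
from ml_tools.utilities import threshold_binary_values_batch

# Valid input: a 2D [batch, features] array whose last 2 columns are binary.
batch = np.random.rand(8, 5)
out = threshold_binary_values_batch(input_array=batch, binary_values=2)

# As of 9.0.0, a 1D input is logged via _LOGGER.error(...) and rejected with
# a bare AssertionError; the message lives in the log, not on the exception.
threshold_binary_values_batch(input_array=np.random.rand(5), binary_values=2)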
@@ -380,15 +393,13 @@ def serialize_object(obj: Any, save_dir: Union[str,Path], filename: str, verbose
         full_path = save_path / sanitized_name
         joblib.dump(obj, full_path)
     except (IOError, OSError, TypeError, TerminatedWorkerError) as e:
-
+        _LOGGER.error(f"Failed to serialize object of type '{type(obj)}'.")
         if raise_on_error:
-            raise
-        else:
-            _LOGGER.warning(message)
+            raise e
         return None
     else:
         if verbose:
-            _LOGGER.info(f"
+            _LOGGER.info(f"Object of type '{type(obj)}' saved to '{full_path}'")
         return None
 
 
@@ -407,15 +418,13 @@ def deserialize_object(filepath: Union[str,Path], verbose: bool=True, raise_on_e
     try:
         obj = joblib.load(true_filepath)
     except (IOError, OSError, EOFError, TypeError, ValueError) as e:
-
+        _LOGGER.error(f"Failed to deserialize object from '{true_filepath}'.")
         if raise_on_error:
-            raise
-        else:
-            _LOGGER.warning(message)
+            raise e
         return None
     else:
         if verbose:
-            _LOGGER.info(f"
+            _LOGGER.info(f"Loaded object of type '{type(obj)}'.")
         return obj
 
 
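Both serializer hunks above converge on one failure contract: the error is always logged with _LOGGER.error, raise_on_error=True re-raises the caught exception as `raise e`, and raise_on_error=False returns None. A sketch of calling code under the new semantics (the path and the fallback helper are hypothetical):

from ml_tools.utilities import deserialize_object, serialize_object

# The failure reason is already in the log; callers that opt out of the
# raise must now check for None rather than expect a warning.
model = deserialize_object("artifacts/model.joblib", raise_on_error=False)
if model is None:
    model = build_fallback_model()  # hypothetical recovery step

serialize_object(model, save_dir="artifacts", filename="model", verbose=True)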
@@ -486,7 +495,8 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
     for dir in list_of_dirs:
         dir_path = make_fullpath(dir)
         if not dir_path.is_dir():
-
+            _LOGGER.error(f"'{dir}' is not a directory.")
+            raise IOError()
         all_dir_paths.append(dir_path)
 
     # main loop
@@ -502,10 +512,10 @@ def train_dataset_orchestrator(list_of_dirs: list[Union[str,Path]],
                 save_dataframe(df=df, save_dir=save_dir, filename=filename)
                 total_saved += 1
             except Exception as e:
-                _LOGGER.
+                _LOGGER.error(f"Failed to process file '{df_path}'. Reason: {e}")
                 continue
 
-    _LOGGER.info(f"
+    _LOGGER.info(f"{total_saved} single-target datasets were created.")
 
 
 def train_dataset_yielder(
@@ -530,6 +540,5 @@ def train_dataset_yielder(
         yield (df_features, df_target, feature_names, target_col)
 
 
-
 def info():
     _script_info(__all__)
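Every utilities.py hunk in this release applies the same convention: the human-readable message moves into a _LOGGER call and the exception is raised bare. A minimal sketch of that pattern, assuming a standard logging-based _LOGGER (the package's actual _logger module also changed in this release and is not shown here; _validate_kind is a hypothetical helper):

import logging

_LOGGER = logging.getLogger("ml_tools")

def _validate_kind(kind: str) -> None:
    # Mirrors the load_dataframe() hunk: log the full message, then raise
    # an exception that carries no message of its own.
    if kind not in ("pandas", "polars"):
        _LOGGER.error(f"Invalid kind '{kind}'. Must be one of 'pandas' or 'polars'.")
        raise ValueError()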
dragon_ml_toolbox-8.1.0.dist-info/RECORD
DELETED

@@ -1,36 +0,0 @@
-dragon_ml_toolbox-8.1.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
-dragon_ml_toolbox-8.1.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
-ml_tools/ETL_engineering.py,sha256=4wwZXi9_U7xfCY70jGBaKniOeZ0m75ppxWpQBd_DmLc,39369
-ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
-ml_tools/MICE_imputation.py,sha256=oFHg-OytOzPYTzBR_wIRHhP71cMn3aupDeT59ABsXlQ,11576
-ml_tools/ML_callbacks.py,sha256=noedVMmHZ72Odbg28zqx5wkhhvX2v-jXicKE_NCAiqU,13838
-ml_tools/ML_datasetmaster.py,sha256=tN-GBPEwXRWFBT8r8K0v9b3Bd77DhqSH5FkjDP6BHTw,28847
-ml_tools/ML_evaluation.py,sha256=BER5dOvSTySNzO92gm8tIpqJ5vT-s0iHMmaoly1uUH8,16018
-ml_tools/ML_evaluation_multi.py,sha256=uVtKGYWgOLv34Xj_jz6E_HAYzNb0HwRbMwA8oFZWpUk,12395
-ml_tools/ML_inference.py,sha256=hwtAdyDCE1xtqLgJgyOTAPck0eTmkOCJK1cM_IJSdck,22824
-ml_tools/ML_models.py,sha256=xZiSFh7S6eitl-VjjvNpsikojDvurK8n_ueLEh6_5pM,27979
-ml_tools/ML_optimization.py,sha256=GX-qZ2mCI3gWRCTP5w7lXrZpfGle3J_mE0O68seIoio,13475
-ml_tools/ML_scaler.py,sha256=pGkp1nUpeuoBvbq5hUkieQdxex6kNef1mEbeS_HUCJs,7471
-ml_tools/ML_trainer.py,sha256=6JSmEQaCPSo-S_5plNBTPw-SYgzZpyMNwiqpShJf7qU,23726
-ml_tools/PSO_optimization.py,sha256=9Y074d-B5h4Wvp9YPiy6KAeXM-Yv6Il3gWalKvOLVgo,22705
-ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
-ml_tools/SQL.py,sha256=bkSTmMV4CtEqa67hApYWaRxTqwAlKIc5_b28P1bnDwg,10475
-ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
-ml_tools/_ML_optimization_multi.py,sha256=DrNG3Vf1uUw-3CpYfXREgSGuR4dTpLWY1F3R9j-PYqQ,9816
-ml_tools/__init__.py,sha256=q0y9faQ6e17XCQ7eUiCZ1FJ4Bg5EQqLjZ9f_l5REUUY,41
-ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
-ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
-ml_tools/custom_logger.py,sha256=nyLRxaRxkqYOFdSjI0X2BWXB8C2IU18QfmqIFKqSedI,5820
-ml_tools/data_exploration.py,sha256=RuMHWagXrSQi1MzAMlYeBeVg7UxhVvEq8gJ9bIam2BM,27103
-ml_tools/ensemble_evaluation.py,sha256=wnqoTPg4WYWf2A8z5XT0eSlW4snEuLCXQVj88sZKzQ4,24683
-ml_tools/ensemble_inference.py,sha256=rtU7eUaQne615n2g7IHZCJI-OvrBCcjxbTkEIvtCGFQ,9414
-ml_tools/ensemble_learning.py,sha256=dAyFgSTyvxJWjc_enJ_8EUoWwiekBeoNyJNxVY-kcUU,21868
-ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
-ml_tools/keys.py,sha256=HtPG8-MWh89C32A7eIlfuuA-DLwkxGkoDfwR2TGN9CQ,1074
-ml_tools/optimization_tools.py,sha256=EL5tgNFwRo-82pbRE1CFVy9noNhULD7wprWuKadPheg,5090
-ml_tools/path_manager.py,sha256=Z8e7w3MPqQaN8xmTnKuXZS6CIW59BFwwqGhGc00sdp4,13692
-ml_tools/utilities.py,sha256=LqXXTovaHbA5AOKRk6Ru6DgAPAM0wPfYU70kUjYBryo,19231
-dragon_ml_toolbox-8.1.0.dist-info/METADATA,sha256=qGTl4__H1ZsbyJHtExcDt14i8ziWXpEy2WaRAELPmTI,6778
-dragon_ml_toolbox-8.1.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-8.1.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-8.1.0.dist-info/RECORD,,
ml_tools/_ML_optimization_multi.py
DELETED

@@ -1,231 +0,0 @@
-import pandas as pd
-import torch
-import numpy as np
-import evotorch
-from evotorch.algorithms import NSGA2
-from evotorch.logging import PandasLogger
-from typing import Literal, Union, Tuple, List, Optional, Any, Callable
-from pathlib import Path
-from tqdm.auto import trange
-from functools import partial
-from contextlib import nullcontext
-import matplotlib.pyplot as plt
-import seaborn as sns
-
-from .path_manager import make_fullpath, sanitize_filename
-from ._logger import _LOGGER
-from ._script_info import _script_info
-from .ML_inference import PyTorchInferenceHandlerMulti # Using the multi-target handler
-from .keys import PyTorchInferenceKeys
-from .utilities import threshold_binary_values, save_dataframe
-from .SQL import DatabaseManager # Added for SQL saving
-
-__all__ = [
-    "create_multi_objective_problem",
-    "run_multi_objective_optimization",
-    "plot_pareto_front"
-]
-
-
-def create_multi_objective_problem(
-    inference_handler: PyTorchInferenceHandlerMulti,
-    bounds: Tuple[List[float], List[float]],
-    binary_features: int,
-    objective_senses: Tuple[Literal["min", "max"], ...],
-    algorithm: Literal["NSGA2"] = "NSGA2",
-    population_size: int = 200,
-    **searcher_kwargs
-) -> Tuple[evotorch.Problem, Callable[[], Any]]:
-    """
-    Creates and configures an EvoTorch Problem and a Searcher for multi-objective optimization.
-
-    This function sets up a problem where the goal is to optimize multiple conflicting
-    objectives simultaneously, using an algorithm like NSGA2 to find the Pareto front.
-
-    Args:
-        inference_handler (PyTorchInferenceHandlerMulti): An initialized handler for the multi-target model.
-        bounds (tuple[list[float], list[float]]): Lower and upper bounds for the solution features.
-        binary_features (int): Number of binary features at the end of the feature vector.
-        objective_senses (Tuple[Literal["min", "max"], ...]): A tuple specifying the optimization
-            goal for each target (e.g., ("max", "min", "max")). The length of this tuple
-            must match the number of outputs from the model.
-        algorithm (str): The multi-objective search algorithm to use. Currently supports "NSGA2".
-        population_size (int): The number of solutions in each generation.
-        **searcher_kwargs: Additional keyword arguments for the search algorithm's constructor.
-
-    Returns:
-        A tuple containing the configured multi-objective Problem and the Searcher factory.
-    """
-    lower_bounds, upper_bounds = list(bounds[0]), list(bounds[1])
-
-    if binary_features > 0:
-        lower_bounds.extend([0.45] * binary_features)
-        upper_bounds.extend([0.55] * binary_features)
-
-    solution_length = len(lower_bounds)
-    device = inference_handler.device
-
-    def fitness_func(solution_tensor: torch.Tensor) -> torch.Tensor:
-        """
-        The fitness function for a multi-objective problem.
-        It returns the entire output tensor from the model. EvoTorch handles the rest.
-        """
-        # The handler returns a tensor of shape [batch_size, num_targets]
-        predictions = inference_handler.predict_batch(solution_tensor)[PyTorchInferenceKeys.PREDICTIONS]
-        return predictions
-
-    if algorithm == "NSGA2":
-        problem = evotorch.Problem(
-            objective_sense=objective_senses,
-            objective_func=fitness_func,
-            solution_length=solution_length,
-            bounds=(lower_bounds, upper_bounds),
-            device=device,
-            vectorized=True,
-            num_actors='max' # Use available CPU cores
-        )
-        SearcherClass = NSGA2
-        if 'popsize' not in searcher_kwargs:
-            searcher_kwargs['popsize'] = population_size
-    else:
-        raise ValueError(f"Unknown multi-objective algorithm '{algorithm}'.")
-
-    searcher_factory = partial(SearcherClass, problem, **searcher_kwargs)
-    return problem, searcher_factory
-
-
-def run_multi_objective_optimization(
-    problem: evotorch.Problem,
-    searcher_factory: Callable[[], Any],
-    num_generations: int,
-    run_name: str,
-    binary_features: int,
-    save_dir: Union[str, Path],
-    feature_names: List[str],
-    target_names: List[str],
-    save_format: Literal['csv', 'sqlite', 'both'] = 'csv',
-    verbose: bool = True
-):
-    """
-    Runs the multi-objective evolutionary optimization process to find the Pareto front.
-
-    This function executes a multi-objective algorithm (like NSGA2) and saves the
-    entire set of non-dominated solutions (the Pareto front) to the specified format(s).
-    It also generates and saves a plot of the Pareto front.
-
-    Args:
-        problem (evotorch.Problem): The configured multi-objective problem.
-        searcher_factory (Callable): A factory function to generate a fresh searcher instance.
-        num_generations (int): The number of generations to run the algorithm.
-        run_name (str): A name for this optimization run, used for filenames/table names.
-        binary_features (int): Number of binary features in the solution vector.
-        save_dir (str | Path): The directory where the result files will be saved.
-        feature_names (List[str]): Names of the solution features for labeling columns.
-        target_names (List[str]): Names of the target objectives for labeling columns.
-        save_format (str): The format to save results in ('csv', 'sqlite', or 'both').
-        verbose (bool): If True, attaches a logger and saves the evolution history.
-    """
-    save_path = make_fullpath(save_dir, make=True, enforce="directory")
-    sanitized_run_name = sanitize_filename(run_name)
-
-    if len(target_names) != problem.num_objectives:
-        raise ValueError("The number of `target_names` must match the number of objectives in the problem.")
-
-    searcher = searcher_factory()
-    _LOGGER.info(f"🤖 Starting multi-objective optimization with {searcher.__class__.__name__} for {num_generations} generations...")
-
-    logger = PandasLogger(searcher) if verbose else None
-    searcher.run(num_generations)
-
-    pareto_front = searcher.status["pareto_front"]
-    _LOGGER.info(f"✅ Optimization complete. Found {len(pareto_front)} non-dominated solutions.")
-
-    solutions_np = pareto_front.values.cpu().numpy()
-    objectives_np = pareto_front.evals.cpu().numpy()
-
-    if binary_features > 0:
-        solutions_np = threshold_binary_values(input_array=solutions_np, binary_values=binary_features)
-
-    results_df = pd.DataFrame(solutions_np, columns=feature_names)
-    objective_cols = []
-    for i, name in enumerate(target_names):
-        col_name = f"predicted_{name}"
-        results_df[col_name] = objectives_np[:, i]
-        objective_cols.append(col_name)
-
-    # --- Saving Logic ---
-    if save_format in ['csv', 'both']:
-        csv_path = save_path / f"pareto_front_{sanitized_run_name}.csv"
-        results_df.to_csv(csv_path, index=False)
-        _LOGGER.info(f"📄 Pareto front data saved to '{csv_path.name}'")
-
-    if save_format in ['sqlite', 'both']:
-        db_path = save_path / "Optimization_Multi.db"
-        with DatabaseManager(db_path) as db:
-            db.insert_from_dataframe(
-                table_name=sanitized_run_name,
-                df=results_df,
-                if_exists='replace'
-            )
-        _LOGGER.info(f"🗃️ Pareto front data saved to table '{sanitized_run_name}' in '{db_path.name}'")
-
-    # --- Plotting Logic ---
-    plot_pareto_front(
-        results_df,
-        objective_cols=objective_cols,
-        save_path=save_path / f"pareto_plot_{sanitized_run_name}.svg"
-    )
-
-    if logger:
-        log_df = logger.to_dataframe()
-        save_dataframe(df=log_df, save_dir=save_path / "EvolutionLogs", filename=f"log_{sanitized_run_name}")
-
-
-def plot_pareto_front(results_df: pd.DataFrame, objective_cols: List[str], save_path: Path):
-    """
-    Generates and saves a plot of the Pareto front.
-
-    - For 2 objectives, it creates a 2D scatter plot.
-    - For 3 objectives, it creates a 3D scatter plot.
-    - For >3 objectives, it creates a scatter plot matrix (pairs plot).
-
-    Args:
-        results_df (pd.DataFrame): DataFrame containing the optimization results.
-        objective_cols (List[str]): The names of the columns that hold the objective values.
-        save_path (Path): The full path (including filename) to save the SVG plot.
-    """
-    num_objectives = len(objective_cols)
-    _LOGGER.info(f"🎨 Generating Pareto front plot for {num_objectives} objectives...")
-
-    plt.style.use('seaborn-v0_8-whitegrid')
-
-    if num_objectives == 2:
-        fig, ax = plt.subplots(figsize=(8, 6), dpi=120)
-        ax.scatter(results_df[objective_cols[0]], results_df[objective_cols[1]], alpha=0.7, edgecolors='k')
-        ax.set_xlabel(objective_cols[0])
-        ax.set_ylabel(objective_cols[1])
-        ax.set_title("Pareto Front (2D)")
-
-    elif num_objectives == 3:
-        fig = plt.figure(figsize=(9, 7), dpi=120)
-        ax = fig.add_subplot(111, projection='3d')
-        ax.scatter(results_df[objective_cols[0]], results_df[objective_cols[1]], results_df[objective_cols[2]], alpha=0.7, depthshade=True)
-        ax.set_xlabel(objective_cols[0])
-        ax.set_ylabel(objective_cols[1])
-        ax.set_zlabel(objective_cols[2])
-        ax.set_title("Pareto Front (3D)")
-
-    else: # > 3 objectives
-        _LOGGER.info(" -> More than 3 objectives found, generating a scatter plot matrix.")
-        g = sns.pairplot(results_df[objective_cols], diag_kind="kde", plot_kws={'alpha': 0.6})
-        g.fig.suptitle("Pareto Front (Pairs Plot)", y=1.02)
-        plt.savefig(save_path, bbox_inches='tight')
-        plt.close()
-        _LOGGER.info(f"📊 Pareto plot saved to '{save_path.name}'")
-        return
-
-    plt.tight_layout()
-    plt.savefig(save_path)
-    plt.close()
-    _LOGGER.info(f"📊 Pareto plot saved to '{save_path.name}'")
-
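For reference, the deleted module's docstrings describe a two-step API: build the Problem plus a searcher factory, then run the search and persist the Pareto front. A usage reconstruction from the deleted signatures (handler, bounds, and names are placeholders; 9.0.0 drops this private module while ml_tools/ML_optimization.py remains):

from ml_tools._ML_optimization_multi import (
    create_multi_objective_problem,
    run_multi_objective_optimization,
)

# 'handler' stands in for an initialized PyTorchInferenceHandlerMulti.
problem, searcher_factory = create_multi_objective_problem(
    inference_handler=handler,
    bounds=([0.0, 0.0], [1.0, 1.0]),  # lower/upper bounds per feature
    binary_features=0,
    objective_senses=("max", "min"),  # one sense per model output
)

run_multi_objective_optimization(
    problem=problem,
    searcher_factory=searcher_factory,
    num_generations=100,
    run_name="demo_run",
    binary_features=0,
    save_dir="./results",
    feature_names=["x1", "x2"],
    target_names=["strength", "cost"],
    save_format="both",
)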
{dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/WHEEL
File without changes

{dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE
File without changes

{dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md
File without changes

{dragon_ml_toolbox-8.1.0.dist-info → dragon_ml_toolbox-9.0.0.dist-info}/top_level.txt
File without changes