dragon-ml-toolbox 19.7.0__py3-none-any.whl → 19.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/RECORD +10 -10
- ml_tools/ML_configuration.py +4 -2
- ml_tools/_core/_ML_configuration.py +79 -4
- ml_tools/_core/_ML_optimization_pareto.py +118 -64
- ml_tools/_core/_keys.py +5 -0
- {dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/RECORD
CHANGED

@@ -1,5 +1,5 @@
-dragon_ml_toolbox-19.…
-dragon_ml_toolbox-19.…
+dragon_ml_toolbox-19.8.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-19.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=XBLtvGjvBf-q93a5iylHj94Lm78UzInC-3Cii01jc6I,3127
 ml_tools/ETL_cleaning.py,sha256=cKXyRFaaFs_beAGDnQM54xnML671kq-yJEGjHafW-20,351
 ml_tools/ETL_engineering.py,sha256=cwh1FhtNdUHllUDvho-x3SIVj4KwG_rFQR6VYzWUg0U,898
 ml_tools/GUI_tools.py,sha256=O89rG8WQv6GY1DiphQjIsPzXFCQID6te7q_Sgt1iTkQ,294
@@ -8,7 +8,7 @@ ml_tools/MICE_imputation.py,sha256=tpLM-rdq4sKbc2GHfj7UrkS3DmBZ3B_DlbrklWbI7gI,3
 ml_tools/ML_callbacks.py,sha256=hrfsIpGkQ1G4Ucfio8JDO1TWjiluuLHCmE7r0ScqxNs,218
 ml_tools/ML_chaining_inference.py,sha256=-JD-LbPtFQkEEWyLUuszWvsqE6nbgkKaQBjrwmBPer0,124
 ml_tools/ML_chaining_utilities.py,sha256=TmiVea_66qfB2l3UEVua4Wb5Sg1D75bSz_-Js3DudfA,360
-ml_tools/ML_configuration.py,sha256=…
+ml_tools/ML_configuration.py,sha256=R8ca9q6W_Lm8lQ48qmxWfdMeHJ5o9hmcHhVdekrY_UQ,2730
 ml_tools/ML_configuration_pytab.py,sha256=6BdyL8sdAp6SDCM1DQrKZKo3yXnEgPX8mWXOaYVMhp0,257
 ml_tools/ML_datasetmaster.py,sha256=bbT29BOGjUThcYctd2eA9K4Y6wKU6sewFMZ7tjVgpqo,154
 ml_tools/ML_evaluation.py,sha256=My7W2IDPca7cMgmJoGyqqVzFL36ssaXA5f4MqKtvWBA,319
@@ -58,7 +58,7 @@ ml_tools/_core/_MICE_imputation.py,sha256=_juIymUnNDRWjSLepL8Ee_PncoShbxjR7YtqTt
 ml_tools/_core/_ML_callbacks.py,sha256=qtCrVFHTq-nk4NIsAdwIkfkKwFXX6I-6PoCgqZELp70,16734
 ml_tools/_core/_ML_chaining_inference.py,sha256=vXUPZzuQ2yKU71kkvUsE0xPo0hN-Yu6gfnL0JbXoRjI,7783
 ml_tools/_core/_ML_chaining_utilities.py,sha256=nsYowgRbkIYuzRiHlqsM3tnC3c-8O73CY8DHUF14XL0,19248
-ml_tools/_core/_ML_configuration.py,sha256=…
+ml_tools/_core/_ML_configuration.py,sha256=6lKod_NuXSj0ElYmkkwnRxZEiZctMlX1x4b0ByRKKhg,52281
 ml_tools/_core/_ML_configuration_pytab.py,sha256=C3e4iScqdRePVDoqnic6xXMOW7DNYqpgTCeaFDyMdL4,3286
 ml_tools/_core/_ML_datasetmaster.py,sha256=yU1BMtzz6XumMWCetVACrRLk7WJQwmYhaQ-VAWu9Ots,32043
 ml_tools/_core/_ML_evaluation.py,sha256=bu8qlYzhWSC1B7wNfCC5TSF-oed-uP8EF7TV45VTiBM,37325
@@ -70,7 +70,7 @@ ml_tools/_core/_ML_models.py,sha256=8FUx4-TVghlBF9srh1_5UxovrWPU7YEZ6XXLqwJei88,
 ml_tools/_core/_ML_models_advanced.py,sha256=oU6M5FEBMQ9yPp32cziWh3bz8SXRho07vFMC8ZDVcuU,45002
 ml_tools/_core/_ML_models_pytab.py,sha256=EHHnDG02ghcJORy2gipm3NcrlzL0qygD44o7QGmT1Zs,26297
 ml_tools/_core/_ML_optimization.py,sha256=b1qfHiGyvVoj-ENqDbHTf1jNx55niUWE9KEZJv3vg80,28253
-ml_tools/_core/_ML_optimization_pareto.py,sha256=…
+ml_tools/_core/_ML_optimization_pareto.py,sha256=fad4UjW5TDbCgIsVFk1qmkq8DnU5sahFFuC2DgKAQ3I,36889
 ml_tools/_core/_ML_scaler.py,sha256=Nhu6qli_QezHQi5NKhRb8Z51bBJgzk2nEp_yW4B9H4U,8134
 ml_tools/_core/_ML_sequence_datasetmaster.py,sha256=0YVOPf-y4ZNdgUxropXUWrmInNyGYaUYprYvXf31n9U,17811
 ml_tools/_core/_ML_sequence_evaluation.py,sha256=AiPHtZ9DRpE6zL9n3Tp5eGGD9vrYRkLbZ0Nc274mL7I,8069
@@ -92,7 +92,7 @@ ml_tools/_core/_ensemble_evaluation.py,sha256=17lWl4bWLT1BAMv_fhGf2D3wy-F4jx0Hgn
 ml_tools/_core/_ensemble_inference.py,sha256=PfZG-r65Vw3IAmBJZg9W0zYGEe-QbhfUh_rd2ho-rr8,8610
 ml_tools/_core/_ensemble_learning.py,sha256=X8ghbjDOLMENCWdISXLhDlHQtR3C6SW1tkTBAcfRRPY,22016
 ml_tools/_core/_excel_handler.py,sha256=gV4rSIsiowb0xllpEJxzUKaYDDVpmP_lxs9wZA76-cc,14050
-ml_tools/_core/_keys.py,sha256=…
+ml_tools/_core/_keys.py,sha256=UpTLHMG1j4FB7hCItcqnfAAuSVMK3Rf-i7jcu6Wkf-Y,6836
 ml_tools/_core/_logger.py,sha256=86Ge0sDE_WgwsZBglQRYPyFYX3lcsIo0NzszNPzlxuk,5254
 ml_tools/_core/_math_utilities.py,sha256=IlXAiZgTcLtus03jJOBOyF9ZCQDf8qLGjrCHu9Mrgak,9091
 ml_tools/_core/_models_advanced_base.py,sha256=ceW0V_CcfOnSFqHlxUhVU8-5mtQq4tFyo8TX-xVexrY,4982
@@ -104,7 +104,7 @@ ml_tools/_core/_schema.py,sha256=TM5WVVMoKOvr_Bc2z34sU_gzKlM465PRKTgdZaEOkGY,140
 ml_tools/_core/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/_core/_serde.py,sha256=tsI4EO2Y7jrBMmbQ1pinDsPOrOg-SaPuB-Dt40q0taE,5609
 ml_tools/_core/_utilities.py,sha256=iA8fLWdhsIx4ut2Dp8M_OyU0Y3PPLgGdIklyl17x6xk,22560
-dragon_ml_toolbox-19.…
-dragon_ml_toolbox-19.…
-dragon_ml_toolbox-19.…
-dragon_ml_toolbox-19.…
+dragon_ml_toolbox-19.8.0.dist-info/METADATA,sha256=ywnJLv63NUVz3LMgHTSbOsRZl95rW02_Sozqwq2u-p0,8764
+dragon_ml_toolbox-19.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-19.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-19.8.0.dist-info/RECORD,,
ml_tools/ML_configuration.py
CHANGED
@@ -37,7 +37,8 @@ from ._core._ML_configuration import (
     DragonAutoIntParams,
 
     # --- Training Config ---
-    DragonTrainingConfig,
+    DragonTrainingConfig,
+    DragonParetoConfig,
     info
 )
 
@@ -80,5 +81,6 @@ __all__ = [
     "DragonAutoIntParams",
 
     # --- Training Config ---
-    "DragonTrainingConfig"
+    "DragonTrainingConfig",
+    "DragonParetoConfig",
 ]
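Per the import and __all__ changes above, DragonParetoConfig is now re-exported from the public configuration module next to DragonTrainingConfig. A minimal import sketch, assuming the public module path shown in this diff:

    # Hypothetical import sketch; assumes the re-export shown in the hunk above.
    from ml_tools.ML_configuration import DragonTrainingConfig, DragonParetoConfig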
ml_tools/_core/_ML_configuration.py
CHANGED

@@ -1,4 +1,5 @@
-from typing import Union, Optional, List, Any, Dict, Literal
+from typing import Union, Optional, List, Any, Dict, Literal, Tuple
+from pathlib import Path
 from collections.abc import Mapping
 import numpy as np
 
@@ -51,7 +52,8 @@ __all__ = [
     "DragonAutoIntParams",
 
     # --- Training Config ---
-    "DragonTrainingConfig"
+    "DragonTrainingConfig",
+    "DragonParetoConfig"
 ]
 
 
@@ -445,6 +447,9 @@ class _BaseModelParams(Mapping):
            if isinstance(v, FeatureSchema):
                # Force the repr() string, otherwise json.dump treats it as a list
                clean_dict[k] = repr(v)
+            elif isinstance(v, Path):
+                # JSON cannot serialize Path objects, convert to string
+                clean_dict[k] = str(v)
            else:
                clean_dict[k] = v
        return clean_dict
@@ -646,6 +651,8 @@ class DragonTrainingConfig(_BaseModelParams):
     Configuration object for the training process.
 
     Can be unpacked as a dictionary for logging or accessed as an object.
+
+    Accepts arbitrary keyword arguments which are set as instance attributes.
     """
     def __init__(self,
                  validation_size: float,
@@ -656,7 +663,7 @@ class DragonTrainingConfig(_BaseModelParams):
                  early_stop_patience: Optional[int] = None,
                  scheduler_patience: Optional[int] = None,
                  scheduler_lr_factor: Optional[float] = None,
-
+                 **kwargs: Any) -> None:
         self.validation_size = validation_size
         self.test_size = test_size
         self.initial_learning_rate = initial_learning_rate
@@ -665,7 +672,75 @@ class DragonTrainingConfig(_BaseModelParams):
         self.early_stop_patience = early_stop_patience
         self.scheduler_patience = scheduler_patience
         self.scheduler_lr_factor = scheduler_lr_factor
-
+
+        # Process kwargs with validation
+        for key, value in kwargs.items():
+            # Python guarantees 'key' is a string for **kwargs
+
+            # Allow None in value
+            if value is None:
+                setattr(self, key, value)
+                continue
+
+            if isinstance(value, dict):
+                _LOGGER.error("Nested dictionaries are not supported, unpack them first.")
+                raise TypeError()
+
+            # Check if value is a number or a string or a JSON supported type, except dict
+            if not isinstance(value, (str, int, float, bool, list, tuple)):
+                _LOGGER.error(f"Invalid type for configuration '{key}': {type(value).__name__}")
+                raise TypeError()
+
+            setattr(self, key, value)
+
+
+class DragonParetoConfig(_BaseModelParams):
+    """
+    Configuration object for the Pareto Optimization process.
+    """
+    def __init__(self,
+                 save_directory: Union[str, Path],
+                 target_objectives: Dict[str, Literal["min", "max"]],
+                 continuous_bounds_map: Union[Dict[str, Tuple[float, float]], Dict[str, List[float]]],
+                 columns_to_round: Optional[List[str]] = None,
+                 population_size: int = 400,
+                 generations: int = 1000,
+                 solutions_filename: str = "ParetoSolutions",
+                 float_precision: int = 4,
+                 log_interval: int = 10,
+                 plot_size: Tuple[int, int] = (10, 7),
+                 plot_font_size: int = 16,
+                 discretize_start_at_zero: bool = True):
+        """
+        Configure the Pareto Optimizer.
+
+        Args:
+            save_directory (str | Path): Directory to save artifacts.
+            target_objectives (Dict[str, "min"|"max"]): Dictionary mapping target names to optimization direction.
+                Example: {"price": "max", "error": "min"}
+            continuous_bounds_map (Dict): Bounds for continuous features {name: (min, max)}.
+            columns_to_round (List[str] | None): List of continuous column names that should be rounded to the nearest integer.
+            population_size (int): Size of the genetic population.
+            generations (int): Number of generations to run.
+            solutions_filename (str): Filename for saving Pareto solutions.
+            float_precision (int): Number of decimal places to round standard float columns.
+            log_interval (int): Interval for logging progress.
+            plot_size (Tuple[int, int]): Size of the 2D plots.
+            plot_font_size (int): Font size for plot text.
+            discretize_start_at_zero (bool): Categorical encoding start index. True=0, False=1.
+        """
+        self.save_directory = save_directory
+        self.target_objectives = target_objectives
+        self.continuous_bounds_map = continuous_bounds_map
+        self.columns_to_round = columns_to_round
+        self.population_size = population_size
+        self.generations = generations
+        self.solutions_filename = solutions_filename
+        self.float_precision = float_precision
+        self.log_interval = log_interval
+        self.plot_size = plot_size
+        self.plot_font_size = plot_font_size
+        self.discretize_start_at_zero = discretize_start_at_zero
 
 
 # ----------------------------
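The new DragonParetoConfig above bundles every knob the optimizer needs: bounds, objectives, run length, output naming, and plot styling. A minimal construction sketch, assuming the signature shown in the diff; "price"/"error" follow the docstring example, while the feature names are illustrative placeholders:

    # Sketch only: "temp" and "pressure" are hypothetical feature names, not values from this package.
    from pathlib import Path

    pareto_config = DragonParetoConfig(
        save_directory=Path("results/pareto"),               # Path is converted to str when serialized
        target_objectives={"price": "max", "error": "min"},  # optimization direction per target
        continuous_bounds_map={"temp": (20.0, 95.0), "pressure": (1.0, 5.0)},
        columns_to_round=["temp"],                           # rounded to the nearest integer on save
        population_size=400,
        generations=1000,
    )

DragonTrainingConfig also gains **kwargs support in this release: extra keyword arguments become instance attributes, but only JSON-friendly values (str, int, float, bool, list, tuple, or None) are accepted, and nested dictionaries raise a TypeError.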
ml_tools/_core/_ML_optimization_pareto.py
CHANGED

@@ -20,6 +20,7 @@ from evotorch.operators import functional as func_ops
 from ._SQL import DragonSQL
 from ._ML_inference import DragonInferenceHandler
 from ._ML_chaining_inference import DragonChainInference
+from ._ML_configuration import DragonParetoConfig
 from ._optimization_tools import create_optimization_bounds, plot_optimal_feature_distributions_from_dataframe
 from ._math_utilities import discretize_categorical_values
 from ._utilities import save_dataframe_filename
@@ -57,26 +58,21 @@ class DragonParetoOptimizer:
     def __init__(self,
                  inference_handler: Union[DragonInferenceHandler, DragonChainInference],
                  schema: FeatureSchema,
-
-                 continuous_bounds_map: Union[Dict[str, Tuple[float, float]], Dict[str, List[float]]],
-                 population_size: int = 400,
-                 discretize_start_at_zero: bool = True):
+                 config: DragonParetoConfig):
         """
         Initialize the Pareto Optimizer.
 
         Args:
             inference_handler (DragonInferenceHandler | DragonChainInference): Validated model handler.
             schema (FeatureSchema): Feature schema for bounds and types.
-
-                Example: {"price": "max", "error": "min"}
-            continuous_bounds_map (Dict): Bounds for continuous features {name: (min, max)}.
-            population_size (int): Size of the genetic population.
-            discretize_start_at_zero (bool): Categorical encoding start index.
+            config (DragonParetoConfig): Configuration for the Pareto optimizer.
         """
         self.inference_handler = inference_handler
         self.schema = schema
-        self.…
-
+        self.config = config
+
+        self.target_objectives = config.target_objectives
+        self.discretize_start_at_zero = config.discretize_start_at_zero
 
         # Initialize state for results
         self.pareto_front: Optional[pd.DataFrame] = None
@@ -106,7 +102,7 @@ class DragonParetoOptimizer:
 
         available_targets = self.inference_handler.target_ids
 
-        for name, direction in target_objectives.items():
+        for name, direction in self.target_objectives.items():
             if name not in available_targets:
                 _LOGGER.error(f"Target '{name}' not found in model targets: {available_targets}")
                 raise ValueError()
@@ -124,8 +120,8 @@ class DragonParetoOptimizer:
         # Uses the external tool which reads the schema to set correct bounds for both continuous and categorical
         bounds = create_optimization_bounds(
             schema=schema,
-            continuous_bounds_map=continuous_bounds_map,
-            start_at_zero=discretize_start_at_zero
+            continuous_bounds_map=config.continuous_bounds_map,
+            start_at_zero=self.discretize_start_at_zero
         )
         self.lower_bounds = list(bounds[0])
         self.upper_bounds = list(bounds[1])
@@ -136,7 +132,7 @@ class DragonParetoOptimizer:
             target_indices=self.target_indices, # Used by Standard Handler
             target_names=self.ordered_target_names, # Used by Chain Handler
             categorical_index_map=schema.categorical_index_map,
-            discretize_start_at_zero=discretize_start_at_zero,
+            discretize_start_at_zero=self.discretize_start_at_zero,
             is_chain=self.is_chain
         )
 
@@ -155,7 +151,7 @@ class DragonParetoOptimizer:
         # GeneticAlgorithm. It automatically applies NSGA-II logic (Pareto sorting) when problem is multi-objective.
         self.algorithm = GeneticAlgorithm(
             self.problem,
-            popsize=population_size,
+            popsize=config.population_size,
             operators=[
                 SimulatedBinaryCrossOver(self.problem, tournament_size=3, eta=20.0, cross_over_rate=1.0),
                 GaussianMutation(self.problem, stdev=0.1)
@@ -163,21 +159,17 @@ class DragonParetoOptimizer:
             re_evaluate=False # model is deterministic
         )
 
-    def run(self…
-            generations: int,
-            save_dir: Union[str, Path],
-            log_interval: int = 10) -> pd.DataFrame:
+    def run(self) -> pd.DataFrame:
         """
         Execute the optimization with progress tracking and periodic logging.
 
-        Args:
-            generations (int): Number of generations to evolve.
-            save_dir (str|Path): Directory to save results and plots.
-            log_interval (int): How often (in generations) to log population statistics.
-
         Returns:
             pd.DataFrame: A DataFrame containing the non-dominated solutions (Pareto Front).
         """
+        generations = self.config.generations
+        save_dir = self.config.save_directory
+        log_interval = self.config.log_interval
+
         save_path = make_fullpath(save_dir, make=True, enforce="directory")
         log_file = save_path / "optimization_log.txt"
 
@@ -189,26 +181,41 @@ class DragonParetoOptimizer:
         with open(log_file, "w") as f:
             f.write(f"Pareto Optimization Log - {generations} Generations\n")
             f.write("=" * 60 + "\n")
+
+        # History tracking for visualization
+        history_records = []
 
         # --- Optimization Loop with Progress Bar ---
         with tqdm(total=generations, desc="Evolving Pareto Front", unit="gen") as pbar:
             for gen in range(1, generations + 1):
                 self.algorithm.step()
 
+                # Capture stats for history (every generation for smooth plots)
+                current_evals = self.algorithm.population.evals.clone() # type: ignore
+
+                gen_stats = {}
+                for i, target_name in enumerate(self.ordered_target_names):
+                    vals = current_evals[:, i]
+                    v_mean = float(vals.mean())
+                    v_min = float(vals.min())
+                    v_max = float(vals.max())
+
+                    # Store for plotting
+                    history_records.append({
+                        "Generation": gen,
+                        "Target": target_name,
+                        "Mean": v_mean,
+                        "Min": v_min,
+                        "Max": v_max
+                    })
+
+                    gen_stats[target_name] = (v_mean, v_min, v_max)
+
                 # Periodic Logging of Population Stats to FILE
                 if gen % log_interval == 0 or gen == generations:
                     stats_msg = [f"Gen {gen}:"]
-
-
-                    current_evals = self.algorithm.population.evals
-
-                    for i, target_name in enumerate(self.ordered_target_names):
-                        vals = current_evals[:, i]
-                        v_mean = float(vals.mean())
-                        v_min = float(vals.min())
-                        v_max = float(vals.max())
-
-                        stats_msg.append(f"{target_name}: {v_mean:.3f} (Range: {v_min:.3f}-{v_max:.3f})")
+                    for t_name, (v_mean, v_min, v_max) in gen_stats.items():
+                        stats_msg.append(f"{t_name}: {v_mean:.3f} (Range: {v_min:.3f}-{v_max:.3f})")
 
                     log_line = " | ".join(stats_msg)
 
@@ -217,6 +224,12 @@ class DragonParetoOptimizer:
                         f.write(log_line + "\n")
 
                 pbar.update(1)
+
+        # --- Post-Optimization Visualization ---
+        if history_records:
+            _LOGGER.debug("Generating optimization history plots...")
+            history_df = pd.DataFrame(history_records)
+            self._plot_optimization_history(history_df, save_path)
 
         # --- Extract Pareto Front ---
         # Manually identify the Pareto front from the final population using domination counts
@@ -289,10 +302,6 @@ class DragonParetoOptimizer:
         return pareto_df
 
     def save_solutions(self,
-                       filename: str = "Pareto_Solutions",
-                       save_dir: Optional[Union[str, Path]] = None,
-                       columns_to_round: Optional[List[str]] = None,
-                       float_precision: int = 4,
                        save_to_sql: bool = False,
                        sql_table_name: Optional[str] = None,
                        sql_if_exists: Literal['fail', 'replace', 'append'] = 'replace') -> None:
@@ -301,12 +310,8 @@ class DragonParetoOptimizer:
        for specific continuous columns. Optionally saves to a SQL database.
 
        Args:
-            save_dir (str | Path | None): Directory to save the CSV. If None, uses the optimization directory.
-            filename (str): Name of the file (without .csv extension).
-            columns_to_round (List[str], optional): List of continuous column names that should be rounded to the nearest integer.
-            float_precision (int): Number of decimal places to round standard float columns.
            save_to_sql (bool): If True, also writes the results to a SQLite database in the save_dir.
-            sql_table_name (str, optional): Specific table name for SQL. If None, uses…
+            sql_table_name (str, optional): Specific table name for SQL. If None, uses the solutions filename.
            sql_if_exists (str): Behavior if SQL table exists ('fail', 'replace', 'append').
        """
        if self.pareto_front is None:
@@ -314,11 +319,15 @@ class DragonParetoOptimizer:
             raise ValueError()
 
         # handle directory
-
-
-
-
-
+        save_path = self._metrics_dir
+        if save_path is None:
+            _LOGGER.error("No save directory found. Cannot save solutions.")
+            raise ValueError()
+
+        # unpack values from config
+        filename = self.config.solutions_filename
+        columns_to_round = self.config.columns_to_round
+        float_precision = self.config.float_precision
 
         # Create a copy to avoid modifying the internal state
         df_to_save = self.pareto_front.copy()
@@ -354,8 +363,6 @@ class DragonParetoOptimizer:
         df_to_save[float_cols] = df_to_save[float_cols].round(float_precision)
 
         # Save CSV
-        save_path = make_fullpath(save_dir, make=True, enforce="directory")
-
         # sanitize filename and add extension if missing
         sanitized_filename = sanitize_filename(filename)
         csv_filename = sanitized_filename if sanitized_filename.lower().endswith(".csv") else f"{sanitized_filename}.csv"
@@ -577,7 +584,7 @@ class DragonParetoOptimizer:
         """Standard 2D scatter plot."""
         x_name, y_name = self.ordered_target_names[0], self.ordered_target_names[1]
 
-        plt.figure(figsize=…
+        plt.figure(figsize=self.config.plot_size, dpi=ParetoOptimizationKeys.DPI)
 
         # Use a color gradient based on the Y-axis to make "better" values visually distinct
         sns.scatterplot(
@@ -592,7 +599,7 @@ class DragonParetoOptimizer:
             legend=False
         )
 
-        plt.title(f"Pareto Front: {x_name} vs {y_name}", fontsize=…
+        plt.title(f"Pareto Front: {x_name} vs {y_name}", fontsize=self.config.plot_font_size + 2, pad=ParetoOptimizationKeys.FONT_PAD)
         plt.grid(True, linestyle='--', alpha=0.6)
 
         # Add simple annotation for the 'corners' (extremes)
@@ -616,8 +623,7 @@ class DragonParetoOptimizer:
                 x_target: Union[int, str],
                 y_target: Union[int, str],
                 z_target: Union[int, str],
-                hue_target: Optional[Union[int, str]] = None…
-                save_dir: Optional[Union[str, Path]] = None,):
+                hue_target: Optional[Union[int, str]] = None):
        """
        Public API to generate 3D visualizations for specific targets.
 
@@ -626,15 +632,11 @@ class DragonParetoOptimizer:
            y_target (int|str): Index or name of the target for the Y axis.
            z_target (int|str): Index or name of the target for the Z axis.
            hue_target (int|str, optional): Index or name of the target for coloring. Defaults to z_target if None.
-            save_dir (str|Path, optional): Directory to save plots. Defaults to the directory used during optimization.
        """
-        if…
-
-
-
-        save_dir = self._metrics_dir
-
-        save_path_root = make_fullpath(save_dir, make=True, enforce="directory")
+        if self._metrics_dir is None:
+            _LOGGER.error("No save directory specified and no previous optimization directory found.")
+            raise ValueError()
+        save_path_root = self._metrics_dir
 
         save_path = make_fullpath(save_path_root / ParetoOptimizationKeys.PARETO_PLOTS_DIR, make=True, enforce="directory")
 
@@ -716,6 +718,58 @@ class DragonParetoOptimizer:
         html_path = sub_dir_path / f"Pareto_3D_Interactive.html"
         fig_html.write_html(str(html_path))
 
+
+    def _plot_optimization_history(self, history_df: pd.DataFrame, save_dir: Path):
+        """
+        Generates convergence plots (Mean/Min/Max) for each objective over generations.
+
+        Args:
+            history_df: DataFrame with cols [Generation, Target, Mean, Min, Max]
+            save_dir: Base directory to save plots
+        """
+        # Create subdirectory for history plots
+        plot_dir = make_fullpath(save_dir / ParetoOptimizationKeys.HISTORY_PLOTS_DIR, make=True, enforce="directory")
+
+        unique_targets = history_df["Target"].unique()
+
+        for target in unique_targets:
+            subset = history_df[history_df["Target"] == target]
+
+            # Determine direction (just for annotation/context if needed, but plotting stats is neutral)
+            direction = self.target_objectives.get(target, "unknown")
+
+            plt.figure(figsize=self.config.plot_size, dpi=ParetoOptimizationKeys.DPI)
+
+            # Plot Mean
+            plt.plot(subset["Generation"], subset["Mean"], label="Population Mean", color="#4c72b0", linewidth=2)
+
+            # Plot Min/Max Range
+            plt.fill_between(
+                subset["Generation"],
+                subset["Min"],
+                subset["Max"],
+                color="#4c72b0",
+                alpha=0.15,
+                label="Min-Max Range"
+            )
+
+            # Plot extremes as dashed lines
+            plt.plot(subset["Generation"], subset["Min"], linestyle="--", color="#55a868", alpha=0.6, linewidth=1, label="Min")
+            plt.plot(subset["Generation"], subset["Max"], linestyle="--", color="#c44e52", alpha=0.6, linewidth=1, label="Max")
+
+            plt.title(f"Convergence History: {target} ({direction.upper()})", fontsize=self.config.plot_font_size + 2, pad=ParetoOptimizationKeys.FONT_PAD)
+            plt.xlabel("Generation", labelpad=ParetoOptimizationKeys.FONT_PAD, fontsize=self.config.plot_font_size)
+            plt.ylabel("Target Value", labelpad=ParetoOptimizationKeys.FONT_PAD, fontsize=self.config.plot_font_size)
+            plt.legend(loc='best', fontsize=self.config.plot_font_size)
+            plt.grid(True, linestyle="--", alpha=0.5)
+            plt.xticks(fontsize=self.config.plot_font_size - 4)
+            plt.yticks(fontsize=self.config.plot_font_size - 4)
+
+            plt.tight_layout()
+
+            fname = f"Convergence_{sanitize_filename(target)}.svg"
+            plt.savefig(plot_dir / fname, bbox_inches='tight')
+            plt.close()
+
 class _ParetoFitnessEvaluator:
     """
     Evaluates fitness for Multi-Objective optimization.
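The constructor now takes the config object instead of loose arguments, run() pulls the generation count, save directory, and log interval from it, and save_solutions() keeps only the SQL-related parameters. A minimal end-to-end sketch, assuming the refactored API above; handler, schema, and pareto_config are placeholders built elsewhere:

    # Usage sketch only; `handler`, `schema`, and `pareto_config` are assumed to exist.
    optimizer = DragonParetoOptimizer(
        inference_handler=handler,   # DragonInferenceHandler or DragonChainInference
        schema=schema,               # FeatureSchema
        config=pareto_config,        # DragonParetoConfig replaces the old per-call arguments
    )

    pareto_df = optimizer.run()                  # generations, save_directory, log_interval come from the config
    optimizer.save_solutions(save_to_sql=False)  # filename, rounding, and precision also come from the config

After run() finishes, convergence history plots for each objective are written to the History subdirectory by the new _plot_optimization_history helper.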
ml_tools/_core/_keys.py
CHANGED
@@ -196,6 +196,11 @@ class ParetoOptimizationKeys:
     """Used by the ML optimization pareto module."""
     PARETO_PLOTS_DIR = "Pareto_Plots"
     SQL_DATABASE_FILENAME = "OptimizationResults.db"
+    HISTORY_PLOTS_DIR = "History"
+
+    # Plot Config values
+    FONT_PAD = 10
+    DPI = 400
 
 
 class OptimizationToolsKeys:
{dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/WHEEL
File without changes
{dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/licenses/LICENSE
File without changes
{dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md
File without changes
{dragon_ml_toolbox-19.7.0.dist-info → dragon_ml_toolbox-19.8.0.dist-info}/top_level.txt
File without changes