dragon-ml-toolbox 2.2.1__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
{dragon_ml_toolbox-2.2.1.dist-info → dragon_ml_toolbox-2.4.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 2.2.1
+Version: 2.4.0
 Summary: A collection of tools for data science and machine learning projects
 Author-email: Karl Loza <luigiloza@gmail.com>
 License-Expression: MIT
@@ -37,9 +37,11 @@ Requires-Dist: Pillow
 Provides-Extra: pytorch
 Requires-Dist: torch; extra == "pytorch"
 Requires-Dist: torchvision; extra == "pytorch"
+Provides-Extra: gui
+Requires-Dist: FreeSimpleGUI>=5.2; extra == "gui"
 Dynamic: license-file
 
-# dragon-ml-tools
+# dragon-ml-toolbox
 
 A collection of Python utilities for data science and machine learning, structured as a modular package for easy reuse and installation.
 
@@ -57,7 +59,7 @@ A collection of Python utilities for data science and machine learning, structur
 Install the latest stable release from PyPI:
 
 ```bash
-pip install dragon-ml-tools
+pip install dragon-ml-toolbox
 ```
 
 ### Via GitHub (Editable)
@@ -77,16 +79,26 @@ Install from the conda-forge channel:
 ```bash
 conda install -c conda-forge dragon-ml-toolbox
 ```
-**Note:** This version is outdated or broken due to dependency incompatibilities.
+**Note:** This version is outdated or broken due to dependency incompatibilities. Use PyPI instead.
 
 ## Optional dependencies
 
-**PyTorch**, which provides different builds depending on the **platform** and **hardware acceleration** (e.g., CUDA for NVIDIA GPUs on Linux/Windows, or MPS for Apple Silicon on macOS).
+### FreeSimpleGUI
+
+A wrapper library used to build GUIs. Requires the tkinter backend.
+
+```bash
+pip install dragon-ml-toolbox[gui]
+```
+
+### PyTorch
+
+Different builds are available depending on the **platform** and **hardware acceleration** (e.g., CUDA for NVIDIA GPUs on Linux/Windows, or MPS for Apple Silicon on macOS).
 
 Install the default CPU-only version with:
 
 ```bash
-pip install dragon-ml-tools[pytorch]
+pip install dragon-ml-toolbox[pytorch]
 ```
 
 To make use of GPU acceleration, use the official PyTorch installation instructions:
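For reference, an illustrative CUDA install command (not part of this README; the correct index URL depends on your CUDA version, see https://pytorch.org/get-started/locally/):

```bash
# Illustrative only: CUDA 12.1 wheels for Linux/Windows; adjust cu121 to match your setup
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```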
@@ -108,6 +120,8 @@ from ml_tools.logger import custom_logger
 data_exploration
 datasetmaster
 ensemble_learning
+ETL_engineering
+GUI_tools
 handle_excel
 logger
 MICE_imputation
{dragon_ml_toolbox-2.2.1.dist-info → dragon_ml_toolbox-2.4.0.dist-info}/RECORD RENAMED
@@ -1,8 +1,9 @@
-dragon_ml_toolbox-2.2.1.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
-dragon_ml_toolbox-2.2.1.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
-ml_tools/ETL_engineering.py,sha256=meQwdMUmAGXmrOSF5K5MaIhztvAbwxPeKnPnv8TxBi0,23283
+dragon_ml_toolbox-2.4.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+dragon_ml_toolbox-2.4.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=6cfpIeQ6D4Mcs10nkogQrkVyq1T7i2qXjjNHFoUMOyE,1892
+ml_tools/ETL_engineering.py,sha256=ns8HsLWZhByurvjtUUW10p7If1h1O5-btUfCRXxzkME,31568
+ml_tools/GUI_tools.py,sha256=sKLBWRhwGax3QSVICEduQiTbGhQdwvW0eeHPQMiyOF0,20150
 ml_tools/MICE_imputation.py,sha256=1fovHycZMdZ6OgVh_bk8-r3wGi4rqf6rS10LOEWYaQo,11177
-ml_tools/PSO_optimization.py,sha256=T-wnB94DcRWuRd2M3loDVT4POtIP0MOhs-VilAf1L4E,20974
+ml_tools/PSO_optimization.py,sha256=gi56mF-q6BApYwhAd9jix0xiYz595WTPcUh7afZsRJ4,25378
 ml_tools/VIF_factor.py,sha256=lpM3Z2X_iZfXUWbCbURoeI0Tb196lU0bAsRo7q6AzBM,10235
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ml_tools/_particle_swarm_optimization.py,sha256=b_eNNkA89Y40hj76KauivT8KLScH1B9wF2IXptOqkOw,22220
@@ -13,9 +14,9 @@ ml_tools/handle_excel.py,sha256=Uasx-DX7RNVQSzGHVJhX7UQ9RgBbX5H1ud1Hw_y8Kp4,1294
 ml_tools/logger.py,sha256=_k7WJdpFJj3IsjOgvjLJgUFZyF8RK3Jlgp5tAu_dLQU,4767
 ml_tools/pytorch_models.py,sha256=bpWZsrSwCvHJQkR6UfoPpElsMv9AvmiNErNHC8NYB_I,10132
 ml_tools/trainer.py,sha256=WAZ4EdrZuTOAnGXRWV3XcLNce4s7EKGf2-qchLC08Ik,15702
-ml_tools/utilities.py,sha256=A7Wm1ArpqFG80WKmnkYdtSzIRLvg5x-9nPNidZIbpPA,20671
+ml_tools/utilities.py,sha256=T6AnNEQjUDnMAMSIJ8yZqToAVESIlEKK0bGBEm3sAUU,20670
 ml_tools/vision_helpers.py,sha256=idQ-Ugp1IdsvwXiYyhYa9G3rTRTm37YRpkQDLEpANHM,7701
-dragon_ml_toolbox-2.2.1.dist-info/METADATA,sha256=1Xjem3tZp5rlaFrz5_lQKdtal_jUB9lKRUIlQqYseyE,2974
-dragon_ml_toolbox-2.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-2.2.1.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-2.2.1.dist-info/RECORD,,
+dragon_ml_toolbox-2.4.0.dist-info/METADATA,sha256=LewdCOSOEeCNVLrB37FD39hnESJ7lPt2voeO-nFG-es,3232
+dragon_ml_toolbox-2.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-2.4.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-2.4.0.dist-info/RECORD,,
ml_tools/ETL_engineering.py CHANGED
@@ -2,19 +2,120 @@ import polars as pl
 import re
 from typing import Literal, Union, Optional, Any, Callable, List, Dict
 from .utilities import _script_info
+import pandas as pd
 
 
 __all__ = [
+    "ColumnCleaner",
+    "DataFrameCleaner",
     "TransformationRecipe",
     "DataProcessor",
     "KeywordDummifier",
     "NumberExtractor",
     "MultiNumberExtractor",
+    "RatioCalculator",
     "CategoryMapper",
+    "RegexMapper",
     "ValueBinner",
     "DateFeatureExtractor"
 ]
 
+########## EXTRACT and CLEAN ##########
+
+class ColumnCleaner:
+    """
+    Cleans and standardizes a single pandas Series based on a dictionary of regex-to-value replacement rules.
+
+    Args:
+        rules (Dict[str, str]):
+            A dictionary where each key is a regular expression pattern and
+            each value is the standardized string to replace matches with.
+    """
+    def __init__(self, rules: Dict[str, str]):
+        if not isinstance(rules, dict):
+            raise TypeError("The 'rules' argument must be a dictionary.")
+
+        # Validate that all keys are valid regular expressions
+        for pattern in rules.keys():
+            try:
+                re.compile(pattern)
+            except re.error as e:
+                raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from e
+
+        self.rules = rules
+
+    def clean(self, series: pd.Series) -> pd.Series:
+        """
+        Applies the standardization rules to the provided Series (requires string data).
+
+        Non-matching values are kept as they are.
+
+        Args:
+            series (pd.Series): The pandas Series to clean.
+
+        Returns:
+            pd.Series: A new Series with the values cleaned and standardized.
+        """
+        return series.astype(str).replace(self.rules, regex=True)
+
+
+class DataFrameCleaner:
+    """
+    Orchestrates the cleaning of multiple columns in a pandas DataFrame using a nested dictionary of rules and `ColumnCleaner` objects.
+
+    Args:
+        rules (Dict[str, Dict[str, str]]):
+            A nested dictionary where each top-level key is a column name,
+            and its value is a dictionary of regex rules for that column, as expected by `ColumnCleaner`.
+    """
+    def __init__(self, rules: Dict[str, Dict[str, str]]):
+        if not isinstance(rules, dict):
+            raise TypeError("The 'rules' argument must be a nested dictionary.")
+
+        for col_name, col_rules in rules.items():
+            if not isinstance(col_rules, dict):
+                raise TypeError(
+                    f"The value for column '{col_name}' must be a dictionary "
+                    f"of rules, but got type {type(col_rules).__name__}."
+                )
+
+        self.rules = rules
+
+    def clean(self, df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Applies all defined cleaning rules to the DataFrame.
+
+        Args:
+            df (pd.DataFrame): The pandas DataFrame to clean.
+
+        Returns:
+            pd.DataFrame: A new, cleaned DataFrame.
+        """
+        rule_columns = set(self.rules.keys())
+        df_columns = set(df.columns)
+
+        missing_columns = rule_columns - df_columns
+
+        if missing_columns:
+            # Report all missing columns in a single, clear error message
+            raise ValueError(
+                f"The following columns specified in the cleaning rules "
+                f"were not found in the DataFrame: {sorted(list(missing_columns))}"
+            )
+
+        # Start the process
+        df_cleaned = df.copy()
+
+        for column_name, column_rules in self.rules.items():
+            # Create and apply the specific cleaner for the column
+            cleaner = ColumnCleaner(rules=column_rules)
+            df_cleaned[column_name] = cleaner.clean(df_cleaned[column_name])
+
+        return df_cleaned
+
+
+############ TRANSFORM ####################
+
 # Magic word for rename-only transformation
 _RENAME = "rename"
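The two cleaners compose: `DataFrameCleaner` simply builds one `ColumnCleaner` per configured column. A minimal usage sketch (the column name and rules are hypothetical; unmatched values pass through unchanged):

```python
import pandas as pd
from ml_tools.ETL_engineering import DataFrameCleaner

df = pd.DataFrame({"unit": ["mg/l", "MG/L", "milligrams per liter", "ppm"]})

cleaner = DataFrameCleaner(rules={
    "unit": {
        r"(?i)^mg/l$": "mg/L",                   # case-insensitive exact match
        r"(?i)^milligrams per liter$": "mg/L",
    }
})
print(cleaner.clean(df)["unit"].tolist())  # -> ['mg/L', 'mg/L', 'mg/L', 'ppm']
```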
@@ -336,8 +437,7 @@ class MultiNumberExtractor:
     """
     Extracts multiple numbers from a single polars string column into several new columns.
 
-    This transformer is designed for one-to-many mappings, such as parsing
-    ratios (100:30) or coordinates (10, 25) into separate columns.
+    This transformer is designed for one-to-many mappings, such as parsing coordinates (10, 25) into separate columns.
 
     Args:
         num_outputs (int):
@@ -413,6 +513,59 @@ class MultiNumberExtractor:
         return pl.select(output_expressions)
 
 
+class RatioCalculator:
+    """
+    A transformer that parses a string ratio (e.g., "40:5" or "30/2") and computes the result of the division.
+
+    Args:
+        regex_pattern (str, optional):
+            The regex pattern to find the numerator and denominator. It MUST
+            contain exactly two capturing groups: the first for the
+            numerator and the second for the denominator. Defaults to a
+            pattern that handles common delimiters like ':' and '/'.
+    """
+    def __init__(
+        self,
+        regex_pattern: str = r"(\d+\.?\d*)\s*[:/]\s*(\d+\.?\d*)"
+    ):
+        # --- Validation ---
+        try:
+            if re.compile(regex_pattern).groups != 2:
+                raise ValueError(
+                    "regex_pattern must contain exactly two "
+                    "capturing groups '(...)'."
+                )
+        except re.error as e:
+            raise ValueError(f"Invalid regex pattern provided: {e}") from e
+
+        self.regex_pattern = regex_pattern
+
+    def __call__(self, column: pl.Series) -> pl.Series:
+        """
+        Applies the ratio calculation logic to the input column.
+
+        Args:
+            column (pl.Series): The input Polars Series of ratio strings.
+
+        Returns:
+            pl.Series: A new Series of floats containing the division result.
+                Returns null for invalid formats or division by zero.
+        """
+        # .extract_groups returns a struct with a field for each capture group,
+        # e.g., {"group_1": "40", "group_2": "5"}
+        groups = column.str.extract_groups(self.regex_pattern)
+
+        # Extract numerator and denominator, casting to float;
+        # strict=False ensures that non-matches become null
+        numerator = groups.struct.field("group_1").cast(pl.Float64, strict=False)
+        denominator = groups.struct.field("group_2").cast(pl.Float64, strict=False)
+
+        # Safely perform division, returning null if denominator is 0
+        return pl.when(denominator != 0).then(
+            numerator / denominator
+        ).otherwise(None)
+
+
 class CategoryMapper:
     """
     A transformer that maps string categories to specified numerical values using a dictionary.
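A quick sanity check of `RatioCalculator` (a sketch; the series values are hypothetical, and the result is materialized with `pl.select` since the transformer is normally invoked inside a `DataProcessor` pipeline):

```python
import polars as pl
from ml_tools.ETL_engineering import RatioCalculator

calc = RatioCalculator()
ratios = pl.Series("mix", ["100:30", "30/2", "no ratio", "5:0"])

# Materialize the when/then expression into a concrete Series
result = pl.select(calc(ratios)).to_series()
print(result.to_list())  # -> [3.33..., 15.0, None, None]
```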
@@ -468,6 +621,74 @@ class CategoryMapper:
         return pl.select(final_expr).to_series()
 
 
+class RegexMapper:
+    """
+    A transformer that maps string categories to numerical values based on a
+    dictionary of regular expression patterns.
+
+    The class iterates through the mapping dictionary in order, and the first
+    pattern that matches a given string determines the output value. This
+    "first match wins" logic makes the order of the mapping important.
+
+    Args:
+        mapping (Dict[str, Union[int, float]]):
+            An ordered dictionary where keys are regex patterns and values are
+            the numbers to map to if the pattern is found.
+        unseen_value (Optional[Union[int, float]], optional):
+            The numerical value to use for strings that do not match any
+            of the regex patterns. If None (default), unseen values are
+            mapped to null.
+    """
+    def __init__(
+        self,
+        mapping: Dict[str, Union[int, float]],
+        unseen_value: Optional[Union[int, float]] = None,
+    ):
+        # --- Validation ---
+        if not isinstance(mapping, dict):
+            raise TypeError("The 'mapping' argument must be a dictionary.")
+
+        for pattern, value in mapping.items():
+            try:
+                re.compile(pattern)
+            except re.error as e:
+                raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from e
+            if not isinstance(value, (int, float)):
+                raise TypeError(f"Mapping values must be int or float, but got {type(value)} for pattern '{pattern}'.")
+
+        self.mapping = mapping
+        self.unseen_value = unseen_value
+
+    def __call__(self, column: pl.Series) -> pl.Series:
+        """
+        Applies the regex mapping logic to the input column.
+
+        Args:
+            column (pl.Series): The input Polars Series of string data.
+
+        Returns:
+            pl.Series: A new Series with strings mapped to numbers based on
+                the first matching regex pattern.
+        """
+        # Ensure the column is treated as a string for matching
+        str_column = column.cast(pl.Utf8)
+
+        # Build the when/then/otherwise chain from the inside out.
+        # Start with the final fallback value for non-matches.
+        mapping_expr = pl.lit(self.unseen_value)
+
+        # Iterate through the mapping in reverse to construct the nested expression
+        for pattern, value in reversed(list(self.mapping.items())):
+            mapping_expr = (
+                pl.when(str_column.str.contains(pattern))
+                .then(pl.lit(value))
+                .otherwise(mapping_expr)
+            )
+
+        # Execute the complete expression chain and return the resulting Series
+        return pl.select(mapping_expr).to_series()
+
+
 class ValueBinner:
     """
     A transformer that discretizes a continuous numerical column into a finite number of bins.
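`RegexMapper` in action (a sketch with hypothetical patterns; because the first matching pattern wins, the more specific pattern goes first):

```python
import polars as pl
from ml_tools.ETL_engineering import RegexMapper

mapper = RegexMapper(
    mapping={"very high": 3, "high": 2, "low": 1},  # order matters
    unseen_value=0,
)
risk = pl.Series("risk", ["very high", "high", "low-ish", "unknown"])
print(mapper(risk).to_list())  # -> [3, 2, 1, 0]
```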
ml_tools/GUI_tools.py ADDED
@@ -0,0 +1,496 @@
+import configparser
+from pathlib import Path
+from typing import Optional, Callable, Any, Dict, Tuple, List
+import traceback
+import FreeSimpleGUI as sg
+from functools import wraps
+from .utilities import _script_info
+import numpy as np
+
+
+__all__ = [
+    "PathManager",
+    "ConfigManager",
+    "GUIFactory",
+    "catch_exceptions",
+    "prepare_feature_vector",
+    "update_target_fields"
+]
+
+
+# --- Path Management ---
+class PathManager:
+    """
+    Manages paths for a Python application, supporting both development mode and bundled mode via Briefcase.
+    """
+    def __init__(self, anchor_file: str):
+        """
+        Initializes the PathManager. The package name is automatically inferred
+        from the parent directory of the anchor file.
+
+        Args:
+            anchor_file (str): The absolute path to a file within the project's
+                package, typically `__file__` from a module inside
+                that package (paths.py).
+
+        Note:
+            This inference assumes that the anchor file's parent directory
+            has the same name as the package (e.g., `.../src/my_app/paths.py`).
+            This is a standard and recommended project structure.
+        """
+        resolved_anchor_path = Path(anchor_file).resolve()
+        self.package_name = resolved_anchor_path.parent.name
+        self._is_bundled, self._resource_path_func = self._check_bundle_status()
+
+        if self._is_bundled:
+            # In a Briefcase bundle, resource_path gives an absolute path
+            # to the resource directory.
+            self.package_root = self._resource_path_func(self.package_name, "")
+        else:
+            # In development mode, the package root is the directory
+            # containing the anchor file.
+            self.package_root = resolved_anchor_path.parent
+
+    def _check_bundle_status(self) -> tuple[bool, Optional[Callable]]:
+        """Checks if the app is running in a bundled environment."""
+        try:
+            # This is the function Briefcase provides in a bundled app
+            from briefcase.platforms.base import resource_path
+            return True, resource_path
+        except ImportError:
+            return False, None
+
+    def get_path(self, relative_path: str | Path) -> Path:
+        """
+        Gets the absolute path for a given resource file or directory
+        relative to the package root.
+
+        Args:
+            relative_path (str | Path): The path relative to the package root (e.g., 'helpers/icon.png').
+
+        Returns:
+            Path: The absolute path to the resource.
+        """
+        if self._is_bundled:
+            # Briefcase's resource_path handles resolving the path within the app bundle
+            return self._resource_path_func(self.package_name, str(relative_path))  # type: ignore
+        else:
+            # In dev mode, join package root with the relative path.
+            return self.package_root / relative_path
+
+
+# --- Configuration Management ---
+class _SectionProxy:
+    """A helper class to represent a section of the .ini file as an object."""
+    def __init__(self, parser: configparser.ConfigParser, section_name: str):
+        for option, value in parser.items(section_name):
+            setattr(self, option.lower(), self._process_value(value))
+
+    def _process_value(self, value_str: str) -> Any:
+        """Automatically converts string values to appropriate types."""
+        # Handle None
+        if value_str is None or value_str.lower() == 'none':
+            return None
+        # Handle Booleans
+        if value_str.lower() in ['true', 'yes', 'on']:
+            return True
+        if value_str.lower() in ['false', 'no', 'off']:
+            return False
+        # Handle Integers
+        try:
+            return int(value_str)
+        except ValueError:
+            pass
+        # Handle Floats
+        try:
+            return float(value_str)
+        except ValueError:
+            pass
+        # Handle 'width,height' tuples
+        if ',' in value_str:
+            try:
+                return tuple(map(int, value_str.split(",")))
+            except (ValueError, TypeError):
+                pass
+        # Fallback to the original string
+        return value_str
+
+
+class ConfigManager:
+    """
+    Loads a .ini file and provides access to its values as object attributes.
+    Includes a method to generate a default configuration template.
+    """
+    def __init__(self, config_path: str | Path):
+        """
+        Initializes the ConfigManager and dynamically creates attributes
+        based on the .ini file's sections and options.
+        """
+        config_path = Path(config_path)
+        if not config_path.exists():
+            raise FileNotFoundError(f"Configuration file not found at: {config_path}")
+
+        parser = configparser.ConfigParser(comment_prefixes=('#', ';'), inline_comment_prefixes=('#', ';'))
+        parser.read(config_path)
+
+        for section in parser.sections():
+            setattr(self, section.lower(), _SectionProxy(parser, section))
+
+    @staticmethod
+    def generate_template(file_path: str | Path, force_overwrite: bool = False):
+        """
+        Generates a complete, commented .ini template file that works with the GUIFactory.
+
+        Args:
+            file_path (str | Path): The path where the .ini file will be saved.
+            force_overwrite (bool): If True, overwrites the file if it already exists.
+        """
+        path = Path(file_path)
+        if path.exists() and not force_overwrite:
+            print(f"Configuration file already exists at {path}. Aborting.")
+            return
+
+        config = configparser.ConfigParser()
+
+        config['General'] = {
+            '; The overall theme for the GUI. Find more at https://www.pysimplegui.org/en/latest/call%20reference/#themes-automatic-coloring-of-elements': '',
+            'theme': 'LightGreen6',
+            '; Default font for the application.': '',
+            'font_family': 'Helvetica',
+            '; Title of the main window.': '',
+            'window_title': 'My Application',
+            '; Can the user resize the window? (true/false)': '',
+            'resizable_window': 'false',
+            '; Optional minimum window size (width,height). Leave blank for no minimum.': '',
+            'min_size': '800,600',
+            '; Optional maximum window size (width,height). Leave blank for no maximum.': '',
+            'max_size': ''
+        }
+        config['Layout'] = {
+            '; Default size for continuous input boxes (width,height in characters).': '',
+            'input_size_cont': '16,1',
+            '; Default size for combo/binary boxes (width,height in characters).': '',
+            'input_size_binary': '14,1',
+            '; Default size for buttons (width,height in characters).': '',
+            'button_size': '15,2'
+        }
+        config['Fonts'] = {
+            '; Font settings. Style can be "bold", "italic", "underline", or a combination.': '',
+            'label_size': '11',
+            'label_style': 'bold',
+            'range_size': '9',
+            'range_style': '',
+            'button_size': '14',
+            'button_style': 'bold',
+            'frame_size': '14',
+            'frame_style': ''
+        }
+        config['Colors'] = {
+            '; Use standard hex codes (e.g., #FFFFFF) or color names (e.g., white).': '',
+            '; Color for the text inside a disabled target/output box.': '',
+            'target_text': '#0000D0',
+            '; Background color for a disabled target/output box.': '',
+            'target_background': '#E0E0E0',
+            '; Color for the text on a button.': '',
+            'button_text': '#FFFFFF',
+            '; Background color for a button.': '',
+            'button_background': '#3c8a7e',
+            '; Background color when the mouse is over a button.': '',
+            'button_background_hover': '#5499C7'
+        }
+        config['Meta'] = {
+            '; Optional application version, displayed in the window title.': '',
+            'version': '1.0.0'
+        }
+
+        with open(path, 'w') as configfile:
+            config.write(configfile)
+        print(f"Successfully generated config template at: '{path}'")
+
+
+# --- GUI Factory ---
+class GUIFactory:
+    """
+    Builds styled FreeSimpleGUI elements and layouts using a "building block"
+    approach, driven by a ConfigManager instance.
+    """
+    def __init__(self, config: ConfigManager):
+        """
+        Initializes the factory with a configuration object.
+        """
+        self.config = config
+        sg.theme(self.config.general.theme)
+        sg.set_options(font=(self.config.general.font_family, 12))
+
+    # --- Atomic Element Generators ---
+    def make_button(self, text: str, key: str, **kwargs) -> sg.Button:
+        """
+        Creates a single, styled action button.
+
+        Args:
+            text (str): The text displayed on the button.
+            key (str): The key for the button element.
+            **kwargs: Override default styles or add other sg.Button parameters
+                (e.g., `tooltip='Click me'`, `disabled=True`).
+        """
+        cfg = self.config
+        font = (cfg.fonts.font_family, cfg.fonts.button_size, cfg.fonts.button_style)
+
+        style_args = {
+            "size": cfg.layout.button_size,
+            "font": font,
+            "button_color": (cfg.colors.button_text, cfg.colors.button_background),
+            "mouseover_colors": (cfg.colors.button_text, cfg.colors.button_background_hover),
+            "border_width": 0,
+            **kwargs
+        }
+        return sg.Button(text.title(), key=key, **style_args)
+
+    def make_frame(self, title: str, layout: List[List[sg.Element]], **kwargs) -> sg.Frame:
+        """
+        Creates a styled frame around a given layout.
+
+        Args:
+            title (str): The title displayed on the frame's border.
+            layout (list): The layout to enclose within the frame.
+            **kwargs: Override default styles or add other sg.Frame parameters
+                (e.g., `title_color='red'`, `relief=sg.RELIEF_SUNKEN`).
+        """
+        cfg = self.config
+        font = (cfg.fonts.font_family, cfg.fonts.frame_size)
+
+        style_args = {
+            "font": font,
+            "expand_x": True,
+            "background_color": sg.theme_background_color(),
+            **kwargs
+        }
+        return sg.Frame(title, layout, **style_args)
+
+    # --- General-Purpose Layout Generators ---
+    def generate_continuous_layout(
+        self,
+        data_dict: Dict[str, Tuple[float, float]],
+        is_target: bool = False,
+        layout_mode: str = 'grid',
+        columns_per_row: int = 4
+    ) -> List[List[sg.Column]]:
+        """
+        Generates a layout for continuous features or targets.
+
+        Args:
+            data_dict (dict): Keys are feature names, values are (min, max) tuples.
+            is_target (bool): If True, creates disabled inputs for displaying results.
+            layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
+            columns_per_row (int): Number of feature columns per row when layout_mode is 'grid'.
+
+        Returns:
+            A list of lists of sg.Column elements, ready to be used in a window layout.
+        """
+        cfg = self.config
+        bg_color = sg.theme_background_color()
+        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style)
+
+        columns = []
+        for name, (val_min, val_max) in data_dict.items():
+            key = f"TARGET_{name}" if is_target else name
+            default_text = "" if is_target else str(val_max)
+
+            label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")
+
+            input_style = {"size": cfg.layout.input_size_cont, "justification": "center"}
+            if is_target:
+                input_style["text_color"] = cfg.colors.target_text
+                input_style["disabled_readonly_background_color"] = cfg.colors.target_background
+
+            element = sg.Input(default_text, key=key, disabled=is_target, **input_style)
+
+            if is_target:
+                layout = [[label], [element]]
+            else:
+                range_font = (cfg.fonts.font_family, cfg.fonts.range_size)
+                range_text = sg.Text(f"Range: {int(val_min)}-{int(val_max)}", font=range_font, background_color=bg_color)
+                layout = [[label], [element], [range_text]]
+
+            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)])
+            columns.append(sg.Column(layout, background_color=bg_color))
+
+        if layout_mode == 'row':
+            return [columns]  # A single row containing all columns
+
+        # Default to 'grid' layout
+        return [columns[i:i + columns_per_row] for i in range(0, len(columns), columns_per_row)]
+
+    def generate_combo_layout(
+        self,
+        data_dict: Dict[str, List[Any]],
+        layout_mode: str = 'grid',
+        columns_per_row: int = 4
+    ) -> List[List[sg.Column]]:
+        """
+        Generates a layout for categorical or binary features using Combo boxes.
+
+        Args:
+            data_dict (dict): Keys are feature names, values are lists of options.
+            layout_mode (str): 'grid' for a multi-row grid layout, or 'row' for a single horizontal row.
+            columns_per_row (int): Number of feature columns per row when layout_mode is 'grid'.
+
+        Returns:
+            A list of lists of sg.Column elements, ready to be used in a window layout.
+        """
+        cfg = self.config
+        bg_color = sg.theme_background_color()
+        label_font = (cfg.fonts.font_family, cfg.fonts.label_size, cfg.fonts.label_style)
+
+        columns = []
+        for name, values in data_dict.items():
+            label = sg.Text(name, font=label_font, background_color=bg_color, key=f"_text_{name}")
+            element = sg.Combo(
+                values, default_value=values[0], key=name,
+                size=cfg.layout.input_size_binary, readonly=True
+            )
+            layout = [[label], [element]]
+            layout.append([sg.Text(" ", font=(cfg.fonts.font_family, 2), background_color=bg_color)])
+            columns.append(sg.Column(layout, background_color=bg_color))
+
+        if layout_mode == 'row':
+            return [columns]  # A single row containing all columns
+
+        # Default to 'grid' layout
+        return [columns[i:i + columns_per_row] for i in range(0, len(columns), columns_per_row)]
+
+    # --- Window Creation ---
+    def create_window(self, title: str, layout: List[List[sg.Element]], **kwargs) -> sg.Window:
+        """
+        Creates and finalizes the main application window.
+
+        Args:
+            title (str): The title for the window.
+            layout (list): The final, assembled layout for the window.
+            **kwargs: Additional arguments to pass to the sg.Window constructor
+                (e.g., `location=(100, 100)`, `keep_on_top=True`).
+        """
+        cfg = self.config.general
+        version = getattr(self.config.meta, 'version', None)
+        full_title = f"{title} v{version}" if version else title
+
+        window_args = {
+            "resizable": cfg.resizable_window,
+            "finalize": True,
+            "background_color": sg.theme_background_color(),
+            **kwargs
+        }
+        window = sg.Window(full_title, layout, **window_args)
+
+        if cfg.min_size: window.TKroot.minsize(*cfg.min_size)
+        if cfg.max_size: window.TKroot.maxsize(*cfg.max_size)
+
+        return window
+
+
+# --- Exception Handling Decorator ---
+def catch_exceptions(show_popup: bool = True):
+    """
+    A decorator that wraps a function in a try-except block.
+    If an exception occurs, it's caught and displayed in a popup window.
+    """
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            except Exception:
+                # Format the full traceback to give detailed error info
+                error_msg = traceback.format_exc()
+                if show_popup:
+                    sg.popup_error("An error occurred:", error_msg, title="Error")
+                else:
+                    # Fallback for non-GUI contexts or if popup is disabled
+                    print("--- An exception occurred ---")
+                    print(error_msg)
+                    print("-----------------------------")
+        return wrapper
+    return decorator
+
+
+# --- Inference Helpers ---
+def _default_categorical_processor(feature_name: str, chosen_value: Any) -> List[float]:
+    """
+    Default processor for binary 'True'/'False' strings.
+    Returns a list containing a single float.
+    """
+    return [1.0] if str(chosen_value) == 'True' else [0.0]
+
+
+def prepare_feature_vector(
+    values: Dict[str, Any],
+    feature_order: List[str],
+    continuous_features: List[str],
+    categorical_features: List[str],
+    categorical_processor: Optional[Callable[[str, Any], List[float]]] = None
+) -> np.ndarray:
+    """
+    Validates and converts GUI values into a numpy array for a model.
+    This function supports label encoding and one-hot encoding via the processor.
+
+    Args:
+        values (dict): The values dictionary from a `window.read()` call.
+        feature_order (list): A list of all feature names that have a GUI element.
+            For one-hot encoding, this should be the name of the
+            single GUI element (e.g., 'material_type'), not the
+            expanded feature names (e.g., 'material_is_steel').
+        continuous_features (list): A list of names for continuous features.
+        categorical_features (list): A list of names for categorical features.
+        categorical_processor (callable, optional): A function to process categorical
+            values. It should accept (feature_name, chosen_value) and return a
+            list of floats (e.g., [1.0] for label encoding, [0.0, 1.0, 0.0] for one-hot).
+            If None, a default 'True'/'False' processor is used.
+
+    Returns:
+        A 1D numpy array ready for model inference.
+    """
+    processed_values: List[float] = []
+
+    # Use the provided processor or the default one
+    processor = categorical_processor or _default_categorical_processor
+
+    # Create sets for faster lookups
+    cont_set = set(continuous_features)
+    cat_set = set(categorical_features)
+
+    for name in feature_order:
+        chosen_value = values.get(name)
+
+        if chosen_value is None or chosen_value == '':
+            raise ValueError(f"Feature '{name}' is missing a value.")
+
+        if name in cont_set:
+            try:
+                processed_values.append(float(chosen_value))
+            except (ValueError, TypeError):
+                raise ValueError(f"Invalid input for '{name}'. Please enter a valid number.")
+
+        elif name in cat_set:
+            # The processor returns a list of values (one for label, multiple for one-hot)
+            numeric_values = processor(name, chosen_value)
+            processed_values.extend(numeric_values)
+
+    return np.array(processed_values, dtype=np.float32)
+
+
+def update_target_fields(window: sg.Window, results_dict: Dict[str, Any]):
+    """
+    Updates the GUI's target fields with inference results.
+
+    Args:
+        window (sg.Window): The application's window object.
+        results_dict (dict): A dictionary where keys are target names (without the
+            'TARGET_' prefix) and values are the predicted results.
+    """
+    for target_name, result in results_dict.items():
+        # Format numbers to 2 decimal places, leave other types as-is
+        display_value = f"{result:.2f}" if isinstance(result, (int, float)) else result
+        window[f'TARGET_{target_name}'].update(display_value)
+
+
+def info():
+    _script_info(__all__)
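A minimal sketch of the inference helpers (no window needed; the feature names and the `values` dict, which would normally come from `window.read()`, are hypothetical):

```python
from ml_tools.GUI_tools import prepare_feature_vector

# Hypothetical GUI state: two continuous inputs and one binary combo
values = {"temperature": "21.5", "pressure": "1.2", "is_coated": "True"}

x = prepare_feature_vector(
    values=values,
    feature_order=["temperature", "pressure", "is_coated"],
    continuous_features=["temperature", "pressure"],
    categorical_features=["is_coated"],  # default processor maps 'True' -> 1.0
)
print(x)  # -> [21.5  1.2  1. ], dtype float32, ready for a model
```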
ml_tools/PSO_optimization.py CHANGED
@@ -7,15 +7,27 @@ from sklearn.base import ClassifierMixin
 from typing import Literal, Union, Tuple, Dict, Optional
 import pandas as pd
 from copy import deepcopy
-from .utilities import _script_info, threshold_binary_values, threshold_binary_values_batch, deserialize_object, list_files_by_extension, save_dataframe, make_fullpath
+from .utilities import _script_info, threshold_binary_values, threshold_binary_values_batch, deserialize_object, list_files_by_extension, save_dataframe, make_fullpath, yield_dataframes_from_dir, sanitize_filename
 import torch
 from tqdm import trange
+import logging
+import matplotlib.pyplot as plt
+import seaborn as sns
+from collections import defaultdict
+
+# Configure logger
+logging.basicConfig(
+    level=logging.INFO,
+    format="[%(asctime)s] [%(levelname)s] - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S"
+)
 
 
 __all__ = [
     "ObjectiveFunction",
     "multiple_objective_functions_from_dir",
-    "run_pso"
+    "run_pso",
+    "plot_optimal_feature_distributions"
 ]
@@ -184,6 +196,52 @@ def _save_results(*dicts, save_dir: Union[str,Path], target_name: str):
     save_dataframe(df=df, save_dir=save_dir, filename=f"Optimization_{target_name}")
 
 
+def _run_single_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, random_state: int):
+    """Helper for a single PSO run."""
+    pso_args.update({"seed": random_state})
+
+    best_features, best_target, *_ = _pso(**pso_args)
+
+    # Flip best_target if maximization was used
+    if objective_function.task == "maximization":
+        best_target = -best_target
+
+    # Threshold binary features
+    binary_number = objective_function.binary_features
+    best_features_threshold = threshold_binary_values(best_features, binary_number)
+
+    # Name features and target
+    best_features_named = {name: value for name, value in zip(feature_names, best_features_threshold)}
+    best_target_named = {target_name: best_target}
+
+    return best_features_named, best_target_named
+
+
+def _run_post_hoc_pso(objective_function: ObjectiveFunction, pso_args: dict, feature_names: list[str], target_name: str, repetitions: int):
+    """Helper for post-hoc PSO analysis."""
+    all_best_targets = []
+    all_best_features = [[] for _ in range(len(feature_names))]
+
+    for _ in range(repetitions):
+        best_features, best_target, *_ = _pso(**pso_args)
+
+        if objective_function.task == "maximization":
+            best_target = -best_target
+
+        binary_number = objective_function.binary_features
+        best_features_threshold = threshold_binary_values(best_features, binary_number)
+
+        for i, best_feature in enumerate(best_features_threshold):
+            all_best_features[i].append(best_feature)
+        all_best_targets.append(best_target)
+
+    # Name features and target
+    all_best_features_named = {name: lst for name, lst in zip(feature_names, all_best_features)}
+    all_best_targets_named = {target_name: all_best_targets}
+
+    return all_best_features_named, all_best_targets_named
+
+
 def run_pso(lower_boundaries: list[float],
             upper_boundaries: list[float],
             objective_function: ObjectiveFunction,
@@ -236,6 +294,8 @@ def run_pso(lower_boundaries: list[float],
     -----
     - PSO minimizes the objective function by default; if maximization is desired, it should be handled inside the ObjectiveFunction.
     """
+
+
     # Select device
     if torch.cuda.is_available():
        device = torch.device("cuda")
@@ -243,7 +303,8 @@ def run_pso(lower_boundaries: list[float],
         device = torch.device("mps")
     else:
         device = torch.device("cpu")
-    print(f"[PSO] Using device: '{device}'")
+
+    logging.info(f"Using device: '{device}'")
 
     # set local deep copies to prevent in place list modification
     local_lower_boundaries = deepcopy(lower_boundaries)
@@ -271,7 +332,7 @@ def run_pso(lower_boundaries: list[float],
     if target_name is None:
         target_name = "Target"
 
-    arguments = {
+    pso_arguments = {
         "func": objective_function,
         "lb": lower,
         "ub": upper,
@@ -281,59 +342,17 @@ def run_pso(lower_boundaries: list[float],
         "particle_output": False,
     }
 
+    # Dispatcher
+    if post_hoc_analysis is None or post_hoc_analysis <= 1:
+        features, target = _run_single_pso(objective_function, pso_arguments, names, target_name, random_state)
+    else:
+        features, target = _run_post_hoc_pso(objective_function, pso_arguments, names, target_name, post_hoc_analysis)
+
+    # --- Save Results ---
     save_results_path = make_fullpath(save_results_dir, make=True)
+    _save_results(features, target, save_dir=save_results_path, target_name=target_name)
 
-    if post_hoc_analysis is None or post_hoc_analysis == 1:
-        arguments.update({"seed": random_state})
-
-        best_features, best_target, *_ = _pso(**arguments)
-        # best_features, best_target, _particle_positions, _target_values_per_position = _pso(**arguments)
-
-        # flip best_target if maximization was used
-        if objective_function.task == "maximization":
-            best_target = -best_target
-
-        # threshold binary features
-        best_features_threshold = threshold_binary_values(best_features, binary_number)
-
-        # name features
-        best_features_named = {name: value for name, value in zip(names, best_features_threshold)}
-        best_target_named = {target_name: best_target}
-
-        # save results
-        _save_results(best_features_named, best_target_named, save_dir=save_results_path, target_name=target_name)
-
-        return best_features_named, best_target_named
-    else:
-        all_best_targets = list()
-        all_best_features = [[] for _ in range(size_of_features)]
-        for _ in range(post_hoc_analysis):
-            best_features, best_target, *_ = _pso(**arguments)
-            # best_features, best_target, _particle_positions, _target_values_per_position = _pso(**arguments)
-
-            # flip best_target if maximization was used
-            if objective_function.task == "maximization":
-                best_target = -best_target
-
-            # threshold binary features
-            best_features_threshold = threshold_binary_values(best_features, binary_number)
-
-            for i, best_feature in enumerate(best_features_threshold):
-                all_best_features[i].append(best_feature)
-            all_best_targets.append(best_target)
-
-        # name features
-        all_best_features_named = {name: list_values for name, list_values in zip(names, all_best_features)}
-        all_best_targets_named = {target_name: all_best_targets}
-
-        # save results
-        _save_results(all_best_features_named, all_best_targets_named, save_dir=save_results_path, target_name=target_name)
-
-        return all_best_features_named, all_best_targets_named  # type: ignore
-
-
-def info():
-    _script_info(__all__)
+    return features, target
 
 
 def _pso(func: ObjectiveFunction,
@@ -342,7 +361,9 @@ def _pso(func: ObjectiveFunction,
          device: torch.device,
          swarmsize: int,
          maxiter: int,
-         omega = 0.729, # Clerc and Kennedy’s constriction coefficient
+         omega_start = 0.9, # STARTING inertia weight
+         omega_end = 0.4, # ENDING inertia weight
+         # omega = 0.729, # Clerc and Kennedy’s constriction coefficient
          phip = 1.49445, # Clerc and Kennedy’s constriction coefficient
          phig = 1.49445, # Clerc and Kennedy’s constriction coefficient
          tolerance = 1e-8,
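This release replaces the fixed constriction coefficient with a time-varying inertia weight: as a later hunk shows, each iteration `i` recomputes `omega` on a linear schedule, so the swarm explores widely early on (omega near 0.9) and refines locally toward the end (omega near 0.4):

```python
# Schedule used inside the main loop: omega_start at i = 0, approaching omega_end at i = maxiter
omega = omega_start - (omega_start - omega_end) * (i / maxiter)
```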
@@ -418,7 +439,7 @@ def _pso(func: ObjectiveFunction,
 
     # Initialize positions and velocities
     r = torch.rand((swarmsize, ndim), device=device, requires_grad=False)
-    positions = lb_t + r * (ub_t - lb_t) # shape: (swarmsize, ndim)
+    positions = lb_t + r * (ub_t - lb_t)
     velocities = torch.zeros_like(positions, requires_grad=False)
 
     # Initialize best positions and scores
@@ -428,19 +449,17 @@ def _pso(func: ObjectiveFunction,
     global_best_score = float('inf')
     global_best_position = torch.zeros(ndim, device=device, requires_grad=False)
 
-    # History (optional)
     if particle_output:
         history_positions = []
         history_scores = []
 
-    # Main loop
     previous_best_score = float('inf')
-    progress = trange(maxiter, desc="PSO", unit="iter", leave=True) #tqdm bar
+    progress = trange(maxiter, desc="PSO", unit="iter", leave=True)
     with torch.no_grad():
         for i in progress:
             # Evaluate objective for all particles
-            positions_np = positions.detach().cpu().numpy() # shape: (swarmsize, n_features)
-            scores_np = func(positions_np) # shape: (swarmsize,)
+            positions_np = positions.detach().cpu().numpy()
+            scores_np = func(positions_np)
             scores = torch.tensor(scores_np, device=device, dtype=torch.float32)
 
             # Update personal bests
@@ -454,17 +473,18 @@ def _pso(func: ObjectiveFunction,
                 global_best_score = min_score.item()
                 global_best_position = personal_best_positions[min_idx].clone()
 
-            # Early stopping criteria
             if abs(previous_best_score - global_best_score) < tolerance:
                 progress.set_description(f"PSO (early stop at iteration {i+1})")
                 break
             previous_best_score = global_best_score
 
-            # Optional: track history for debugging/visualization
             if particle_output:
                 history_positions.append(positions.detach().cpu().numpy())
                 history_scores.append(scores_np)
-
+
+            # Linearly decreasing inertia weight
+            omega = omega_start - (omega_start - omega_end) * (i / maxiter)
+
             # Velocity update
             rp = torch.rand((swarmsize, ndim), device=device, requires_grad=False)
             rg = torch.rand((swarmsize, ndim), device=device, requires_grad=False)
@@ -476,11 +496,9 @@ def _pso(func: ObjectiveFunction,
             # Position update
             positions = positions + velocities
 
-            # Clamp to search space bounds
             positions = torch.max(positions, lb_t)
             positions = torch.min(positions, ub_t)
 
-    # Move to CPU and convert to NumPy
     best_position = global_best_position.detach().cpu().numpy()
     best_score = global_best_score
 
@@ -488,3 +506,91 @@ def _pso(func: ObjectiveFunction,
         return best_position, best_score, history_positions, history_scores
     else:
         return best_position, best_score
+
+
+def plot_optimal_feature_distributions(results_dir: Union[str, Path], save_dir: Union[str, Path], color_by_target: bool = True):
+    """
+    Analyzes optimization results and plots the distribution of optimal values for each feature.
+
+    This function can operate in two modes based on the `color_by_target` parameter:
+
+    1. Aggregate: pools all values for a feature into a single group and plots one overall distribution (histogram + KDE).
+    2. Comparative (color-coded): plots a separate, color-coded Kernel Density Estimate (KDE) for each source target, allowing for direct comparison on a single chart.
+
+    Parameters
+    ----------
+    results_dir : str or Path
+        The path to the directory containing the optimization result CSV files.
+    save_dir : str or Path
+        The directory where the output plots will be saved.
+    color_by_target : bool, optional
+        If True, generates comparative plots with distributions colored by their source target.
+    """
+    mode = "Comparative (color-coded)" if color_by_target else "Aggregate"
+    logging.info(f"Starting analysis in '{mode}' mode from results in: '{results_dir}'")
+
+    output_path = make_fullpath(save_dir, make=True)
+    all_files = list(yield_dataframes_from_dir(results_dir))
+
+    if not all_files:
+        logging.warning("No data found. No plots will be generated.")
+        return
+
+    # --- MODE 1: Color-coded plots by target ---
+    if color_by_target:
+        data_to_plot = []
+        for df, df_name in all_files:
+            # Assumes last col is target, rest are features
+            melted_df = df.iloc[:, :-1].melt(var_name='feature', value_name='value')
+            # Sanitize target name for cleaner legend labels
+            melted_df['target'] = df_name.replace("Optimization_", "")
+            data_to_plot.append(melted_df)
+
+        long_df = pd.concat(data_to_plot, ignore_index=True)
+        features = long_df['feature'].unique()
+        logging.info(f"Found data for {len(features)} features across {len(long_df['target'].unique())} targets. Generating plots...")
+
+        for feature_name in features:
+            plt.figure(figsize=(12, 7))
+            feature_df = long_df[long_df['feature'] == feature_name]
+
+            sns.kdeplot(data=feature_df, x='value', hue='target', fill=True, alpha=0.1)
+
+            plt.title(f"Comparative Distribution for '{feature_name}'", fontsize=16)
+            plt.xlabel("Feature Value", fontsize=12)
+            plt.ylabel("Density", fontsize=12)
+            plt.grid(axis='y', alpha=0.5, linestyle='--')
+            plt.legend(title='Target')
+
+            sanitized_feature_name = sanitize_filename(feature_name)
+            plot_filename = output_path / f"Comparative_{sanitized_feature_name}.svg"
+            plt.savefig(plot_filename, bbox_inches='tight')
+            plt.close()
+
+    # --- MODE 2: Aggregate plot ---
+    else:
+        feature_distributions = defaultdict(list)
+        for df, _ in all_files:
+            feature_columns = df.iloc[:, :-1]
+            for feature_name in feature_columns:
+                feature_distributions[feature_name].extend(df[feature_name].tolist())
+
+        logging.info(f"Found data for {len(feature_distributions)} features. Generating plots...")
+        for feature_name, values in feature_distributions.items():
+            plt.figure(figsize=(12, 7))
+            sns.histplot(x=values, kde=True, bins='auto', stat="density")
+
+            plt.title(f"Aggregate Distribution for '{feature_name}'", fontsize=16)
+            plt.xlabel("Feature Value", fontsize=12)
+            plt.ylabel("Density", fontsize=12)
+            plt.grid(axis='y', alpha=0.5, linestyle='--')
+
+            sanitized_feature_name = sanitize_filename(feature_name)
+            plot_filename = output_path / f"Aggregate_{sanitized_feature_name}.svg"
+            plt.savefig(plot_filename, bbox_inches='tight')
+            plt.close()
+
+    logging.info(f"✅ All plots saved successfully to: {output_path}")
+
+
+def info():
+    _script_info(__all__)
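A minimal sketch of the new post-hoc workflow, assuming `run_pso` has already written `Optimization_*.csv` files (the directory names here are hypothetical):

```python
from ml_tools.PSO_optimization import plot_optimal_feature_distributions

# Each CSV in results_dir holds one target's optimization runs:
# feature columns first, target column last (as the function assumes).
plot_optimal_feature_distributions(
    results_dir="pso_results",   # hypothetical directory of Optimization_*.csv files
    save_dir="pso_plots",        # one SVG per feature is written here
    color_by_target=True,        # overlay per-target KDEs; False pools all targets
)
```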
ml_tools/utilities.py CHANGED
@@ -86,7 +86,6 @@ def make_fullpath(
     return resolved
 
 
-
 def list_csv_paths(directory: Union[str,Path]) -> dict[str, Path]:
     """
     Lists all `.csv` files in the specified directory and returns a mapping: filenames (without extensions) to their absolute paths.