dragon-ml-toolbox 20.4.0__py3-none-any.whl → 20.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 20.4.0
+ Version: 20.6.0
  Summary: Complete pipelines and helper tools for data science and machine learning projects.
  Author-email: Karl Luigi Loza Vidaurre <luigiloza@gmail.com>
  License-Expression: MIT
@@ -1,11 +1,11 @@
- dragon_ml_toolbox-20.4.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
- dragon_ml_toolbox-20.4.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
+ dragon_ml_toolbox-20.6.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+ dragon_ml_toolbox-20.6.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
  ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
- ml_tools/ETL_cleaning/__init__.py,sha256=8dsHiguUkI6Ix1759IPdGU3IXcjMz4DyaSCkdYhxxg8,490
+ ml_tools/ETL_cleaning/__init__.py,sha256=gLRHF-qzwpqKTvbbn9chIQELeUDh_XGpBRX28j-5IqI,545
  ml_tools/ETL_cleaning/_basic_clean.py,sha256=2_FhWP-xYgl8s51H3OjYb_sqsW2yX_QZ4kmyrKjbSsc,13892
- ml_tools/ETL_cleaning/_clean_tools.py,sha256=pizTBK69zHt7HpZc_bcX9KoX2loLDcyQJddf_Kl-Ldo,5129
- ml_tools/ETL_cleaning/_dragon_cleaner.py,sha256=dge7KQSO4IdeXV4pCCJCb5lhAzR8rmwZPoCscm1A9KY,10272
+ ml_tools/ETL_cleaning/_clean_tools.py,sha256=7aIC4w0CLK93E2nWC8h8YbI8bW_3Na9myD9VBMA-9zQ,9575
+ ml_tools/ETL_cleaning/_dragon_cleaner.py,sha256=WvDHtdQTQldYwRWkmr3MlqFgWPl8rrEHp6m1uqgH0ho,13291
  ml_tools/ETL_engineering/__init__.py,sha256=EVIU0skxaH4ZDk8tEkOrxhTMSSA2LI_glhIpzFSxxlg,1007
  ml_tools/ETL_engineering/_dragon_engineering.py,sha256=D-D6tmhyQ3I9-cXgxLVVbQBRTZoNsWaKPsvcTUaetws,10810
  ml_tools/ETL_engineering/_transforms.py,sha256=qOxa_vjh3gzS4IiGFqq_0Wnh0ilQO41jRiIp-6Ej4vw,47079
@@ -30,7 +30,7 @@ ml_tools/ML_chain/_update_schema.py,sha256=z1Us7lv6hy6GwSu1mcid50Jmqq3sh91hMQ0Ln
  ml_tools/ML_configuration/__init__.py,sha256=ogktFnYxz5jWJkhHS4DVaMldHkt3lT2gw9jx5PQ3d78,2755
  ml_tools/ML_configuration/_base_model_config.py,sha256=95L3IfobNFMtnNr79zYpDGerC1q1v7M05tWZvTS2cwE,2247
  ml_tools/ML_configuration/_finalize.py,sha256=l_n13bLu0avMdJ8hNRrH8V_wOBQZM1UGsTydKBkTysM,15047
- ml_tools/ML_configuration/_metrics.py,sha256=PqBGPO1Y_6ImmYI3TEBJhzipULE854vbvE0AbP5m8zQ,22888
+ ml_tools/ML_configuration/_metrics.py,sha256=xKtEKzphtidwwU8UuUpGv4B8Y6Bv0tAOjEFUYfz8Ehc,23758
  ml_tools/ML_configuration/_models.py,sha256=lvuuqvD6DWUzOa3i06NZfrdfOi9bu2e26T_QO6BGMSw,7629
  ml_tools/ML_configuration/_training.py,sha256=_M_TwouHFNbGrZQtQNAvyG_poSVpmN99cbyUonZsHhk,8969
  ml_tools/ML_datasetmaster/__init__.py,sha256=UltQzuXnlXVCkD-aeA5TW4IcMVLnQf1_aglawg4WyrI,580
@@ -39,7 +39,7 @@ ml_tools/ML_datasetmaster/_datasetmaster.py,sha256=Oy2UE3YJpKTaFwQF5TkQLgLB54-BF
  ml_tools/ML_datasetmaster/_sequence_datasetmaster.py,sha256=cW3fuILZWs-7Yuo4T2fgGfTC4vwho3Gp4ohIKJYS7O0,18452
  ml_tools/ML_datasetmaster/_vision_datasetmaster.py,sha256=kvSqXYeNBN1JSRfSEEXYeIcsqy9HsJAl_EwFWClqlsw,67025
  ml_tools/ML_evaluation/__init__.py,sha256=e3c8JNP0tt4Kxc7QSQpGcOgrxf8JAucH4UkJvJxUL2E,1122
- ml_tools/ML_evaluation/_classification.py,sha256=xXCh87RE9_VXYalc7l6CbakYfB0rijGrY76RZIrqLBk,28922
+ ml_tools/ML_evaluation/_classification.py,sha256=8bKQejKrgMipnxU1T12ted7p60xvJS0d0MvHtdNBCBM,30971
  ml_tools/ML_evaluation/_feature_importance.py,sha256=mTwi3LKom_axu6UFKunELj30APDdhG9GQC2w7I9mYhI,17137
  ml_tools/ML_evaluation/_loss.py,sha256=1a4O25i3Ya_3naNZNL7ELLUL46BY86g1scA7d7q2UFM,3625
  ml_tools/ML_evaluation/_regression.py,sha256=hnT2B2_6AnQ7aA7uk-X2lZL9G5JFGCduDXyZbr1gFCA,11037
@@ -76,7 +76,7 @@ ml_tools/ML_models_vision/_image_classification.py,sha256=miwMNoTXpmmZSiqeXvDKpx
  ml_tools/ML_models_vision/_image_segmentation.py,sha256=NRjn91bDD2OJWSJFrrNW9s41qgg5w7pw68Q61-kg-As,4157
  ml_tools/ML_models_vision/_object_detection.py,sha256=AOGER5bx0REc-FfBtspJmyLJxn3GdwDSPwFGveobR94,5608
  ml_tools/ML_optimization/__init__.py,sha256=No18Dsw6Q9zPt8B9fpG0bWomuXmwDC7DiokiaPuwmRI,485
- ml_tools/ML_optimization/_multi_dragon.py,sha256=R0G91Y-TK49coCE0NAZdQuEqI0kTEaKuIuZ6QGE99lg,38525
+ ml_tools/ML_optimization/_multi_dragon.py,sha256=zQhDxFY8FNxUlcbSnHMVArfojzYjgNa21jSE3pJmRW0,38956
  ml_tools/ML_optimization/_single_dragon.py,sha256=jh5-SK6NKAzbheQhquiYoROozk-RzUv1jiFkIzK_AFg,7288
  ml_tools/ML_optimization/_single_manual.py,sha256=h-_k9JmRqPkjTra1nu7AyYbSyWkYZ1R3utiNmW06WFs,21809
  ml_tools/ML_scaler/_ML_scaler.py,sha256=P75X0Sx8N-VxC2Qy8aG7mWaZlkTfjspiZDi1YiMQD1I,8872
@@ -103,12 +103,12 @@ ml_tools/_core/__init__.py,sha256=m-VP0RW0tOTm9N5NI3kFNcpM7WtVgs0RK9pK3ZJRZQQ,14
  ml_tools/_core/_logger.py,sha256=xzhn_FouMDRVNwXGBGlPC9Ruq6i5uCrmNaS5jesguMU,4972
  ml_tools/_core/_schema_load_ops.py,sha256=KLs9vBzANz5ESe2wlP-C41N4VlgGil-ywcfvWKSOGss,1551
  ml_tools/_core/_script_info.py,sha256=LtFGt10gEvCnhIRMKJPi2yXkiGLcdr7lE-oIP2XGHzQ,234
- ml_tools/data_exploration/__init__.py,sha256=ahCjELrum2aIj_cLK-sdGbJjTvvolf3US_oaB97rOQg,1736
+ ml_tools/data_exploration/__init__.py,sha256=nYKg1bPBgXibC5nhmNKPw3VaKFeVtlNGL_YpHixW-Pg,1795
  ml_tools/data_exploration/_analysis.py,sha256=H6LryV56FFCHWjvQdkhZbtprZy6aP8EqU_hC2Cf9CLE,7832
  ml_tools/data_exploration/_cleaning.py,sha256=pAZOXgGK35j7O8q6cnyTwYK1GLNnD04A8p2fSyMB1mg,20906
  ml_tools/data_exploration/_features.py,sha256=wW-M8n2aLIy05DR2z4fI8wjpPjn3mOAnm9aSGYbMKwI,23363
  ml_tools/data_exploration/_plotting.py,sha256=zH1dPcIoAlOuww23xIoBCsQOAshPPv9OyGposOA2RvI,19883
- ml_tools/data_exploration/_schema_ops.py,sha256=PoFeHaS9dXI9gfL0SRD-8uSP4owqmbQFbtfA-HxkLnY,7108
+ ml_tools/data_exploration/_schema_ops.py,sha256=Fd6fBGGv4OpxmJ1HG9pith6QL90z0tzssCvzkQxlEEQ,11083
  ml_tools/ensemble_evaluation/__init__.py,sha256=t4Gr8EGEk8RLatyc92-S0BzbQvdvodzoF-qDAH2qjVg,546
  ml_tools/ensemble_evaluation/_ensemble_evaluation.py,sha256=-sX9cLMaa0FOQDikmVv2lsCYtQ56Kftd3tILnNej0Hg,28346
  ml_tools/ensemble_inference/__init__.py,sha256=VMX-Kata2V0UmiURIU2jx6mRuZmvTWf-QXzCpHmVGZA,255
@@ -118,7 +118,7 @@ ml_tools/ensemble_learning/_ensemble_learning.py,sha256=MHDZBR20_nStlSSeThFI3bSu
  ml_tools/excel_handler/__init__.py,sha256=AaWM3n_dqBhJLTs3OEA57ex5YykKXNOwVCyHlVsdnqI,530
  ml_tools/excel_handler/_excel_handler.py,sha256=TODudmeQgDSdxUKzLfAzizs--VL-g8WxDOfQ4sgxxLs,13965
  ml_tools/keys/__init__.py,sha256=-0c2pmrhyfROc-oQpEjJGLBMhSagA3CyFijQaaqZRqU,399
- ml_tools/keys/_keys.py,sha256=kBcW3euNmD57_4aoRaAeqJP3FtU3iSuvgYv-BZqnEWw,9290
+ ml_tools/keys/_keys.py,sha256=lL9NlijxOEAhfDPPqK_wL3QhjalrYK_fWM-KNniSIOA,9308
  ml_tools/math_utilities/__init__.py,sha256=K7Obkkc4rPKj4EbRZf1BsXHfiCg7FXYv_aN9Yc2Z_Vg,400
  ml_tools/math_utilities/_math_utilities.py,sha256=BYHIVcM9tuKIhVrkgLLiM5QalJ39zx7dXYy_M9aGgiM,9012
  ml_tools/optimization_tools/__init__.py,sha256=KD8JXpfGuPndO4AHnjJGu6uV1GRwhOfboD0KZV45kzw,658
@@ -130,14 +130,14 @@ ml_tools/path_manager/_path_tools.py,sha256=LcZE31QlkzZWUR8g1MW_N_mPY2DpKBJLA45V
  ml_tools/plot_fonts/__init__.py,sha256=KIxXRCjQ3SliEoLhEcqs7zDVZbVTn38bmSdL-yR1Q2w,187
  ml_tools/plot_fonts/_plot_fonts.py,sha256=mfjXNT9P59ymHoTI85Q8CcvfxfK5BIFBWtTZH-hNIC4,2209
  ml_tools/schema/__init__.py,sha256=K6uiZ9f0GCQ7etw1yl2-dQVLhU7RkL3KHesO3HNX6v4,334
- ml_tools/schema/_feature_schema.py,sha256=aVY3AJt1j4D2mtusVy2l6lDR2SYzPMyfvG1o9zOn0Kw,8585
+ ml_tools/schema/_feature_schema.py,sha256=MuPf6Nf7tDhUTGyX7tcFHZh-lLSNsJkLmlf9IxdF4O4,9660
  ml_tools/schema/_gui_schema.py,sha256=IVwN4THAdFrvh2TpV4SFd_zlzMX3eioF-w-qcSVTndE,7245
  ml_tools/serde/__init__.py,sha256=IDirr8i-qjUHB71hmHO6lGiODhUoOnUcXYrvb_XgrzE,292
  ml_tools/serde/_serde.py,sha256=8QnYK8ZG21zdNaC0v63iSz2bhgwOKRKAWxTVQvMV0A8,5525
  ml_tools/utilities/__init__.py,sha256=iQb-S5JesEjGGI8983Vkj-14LCtchFxdWRhaziyvnoY,808
  ml_tools/utilities/_utility_save_load.py,sha256=EFvFaTaHahDQWdJWZr-j7cHqRbG_Xrpc96228JhV-bs,16773
  ml_tools/utilities/_utility_tools.py,sha256=bN0J9d1S0W5wNzNntBWqDsJcEAK7-1OgQg3X2fwXns0,6918
- dragon_ml_toolbox-20.4.0.dist-info/METADATA,sha256=5r7luC7aniRGcoQ5qy94fFLwme7UldbcfXFI-m_6hlA,7866
- dragon_ml_toolbox-20.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dragon_ml_toolbox-20.4.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
- dragon_ml_toolbox-20.4.0.dist-info/RECORD,,
+ dragon_ml_toolbox-20.6.0.dist-info/METADATA,sha256=HfSazpvNdCk-0TW27NgJuerpBdsrzGhmmUnO3g1FMe4,7866
+ dragon_ml_toolbox-20.6.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-20.6.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-20.6.0.dist-info/RECORD,,
ml_tools/ETL_cleaning/__init__.py CHANGED
@@ -10,7 +10,8 @@ from ._dragon_cleaner import (
  )

  from ._clean_tools import (
- save_unique_values
+ save_unique_values,
+ save_category_counts,
  )

  from .._core import _imprimir_disponibles
@@ -20,6 +21,7 @@ __all__ = [
  "DragonColumnCleaner",
  "DragonDataFrameCleaner",
  "save_unique_values",
+ "save_category_counts",
  "basic_clean",
  "basic_clean_drop",
  "drop_macro_polars",
ml_tools/ETL_cleaning/_clean_tools.py CHANGED
@@ -13,6 +13,7 @@ _LOGGER = get_logger("ETL Clean Tools")

  __all__ = [
  "save_unique_values",
+ "save_category_counts",
  ]


@@ -126,3 +127,111 @@ def save_unique_values(csv_path_or_df: Union[str, Path, pl.DataFrame],
  counter += 1

  _LOGGER.info(f"{counter} files of unique values created.")
+
+
+ ################ Category Counts per column #################
+ def save_category_counts(csv_path_or_df: Union[str, Path, pl.DataFrame],
+ output_dir: Union[str, Path],
+ use_columns: Optional[list[str]] = None,
+ verbose: bool = False,
+ keep_column_order: bool = True) -> None:
+ """
+ Calculates the frequency and percentage of each unique value in the specified columns
+ and saves the distribution report to a text file.
+
+ Useful for checking class balance or identifying rare categories.
+
+ Args:
+ csv_path_or_df (str | Path | pl.DataFrame):
+ The file path to the input CSV file or a Polars DataFrame.
+ output_dir (str | Path):
+ The directory where the report files will be saved.
+ use_columns (List[str] | None):
+ Columns to analyze. If None, all columns are processed.
+ verbose (bool):
+ If True, prints progress info.
+ keep_column_order (bool):
+ If True, prepends a numeric prefix to filenames to maintain order.
+ """
+ # 1. Handle Input
+ if isinstance(csv_path_or_df, pl.DataFrame):
+ df = csv_path_or_df
+ if use_columns:
+ valid_cols = [c for c in use_columns if c in df.columns]
+ if not valid_cols:
+ _LOGGER.error("None of the specified columns in 'use_columns' exist in the provided DataFrame.")
+ raise ValueError()
+ df = df.select(valid_cols)
+ else:
+ csv_path = make_fullpath(input_path=csv_path_or_df, enforce="file")
+ df = load_dataframe(df_path=csv_path, use_columns=use_columns, kind="polars", all_strings=True)[0]
+
+ output_path = make_fullpath(input_path=output_dir, make=True, enforce='directory')
+ total_rows = df.height
+
+ if total_rows == 0:
+ _LOGGER.warning("Input DataFrame is empty. No counts to save.")
+ return
+
+ counter = 0
+
+ # 2. Process Each Column
+ for i, col_name in enumerate(df.columns):
+ try:
+ # Group by, count, and calculate percentage
+ # We treat nulls as a category here to see missing data frequency
+ stats = (
+ df.select(pl.col(col_name))
+ .group_by(col_name, maintain_order=False)
+ .len(name="count")
+ .with_columns(
+ (pl.col("count") / total_rows * 100).alias("pct")
+ )
+ .sort("count", descending=True)
+ )
+
+ # Collect to python list of dicts for writing
+ rows = stats.iter_rows(named=True)
+ unique_count = stats.height
+
+ # Check thresholds for warning
+ is_high_cardinality = (unique_count > 300) or ((unique_count / total_rows) > 0.5)
+
+ except Exception:
+ _LOGGER.error(f"Could not calculate counts for column '{col_name}'.")
+ continue
+
+ # 3. Write to File
+ sanitized_name = sanitize_filename(col_name)
+ if not sanitized_name.strip('_'):
+ sanitized_name = f'column_{i}'
+
+ prefix = f"{i + 1}_" if keep_column_order else ''
+ file_path = output_path / f"{prefix}{sanitized_name}_counts.txt"
+
+ try:
+ with open(file_path, 'w', encoding='utf-8') as f:
+ f.write(f"# Distribution for column: '{col_name}'\n")
+ f.write(f"# Total Rows: {total_rows} | Unique Values: {unique_count}\n")
+
+ if is_high_cardinality:
+ f.write(f"# WARNING: High cardinality detected (Unique/Total ratio: {unique_count/total_rows:.2%}).\n")
+
+ f.write("-" * 65 + "\n")
+ f.write(f"{'Count':<10} | {'Percentage':<12} | {'Value'}\n")
+ f.write("-" * 65 + "\n")
+
+ for row in rows:
+ val = str(row[col_name])
+ count = row["count"]
+ pct = row["pct"]
+ f.write(f"{count:<10} | {pct:>10.2f}% | {val}\n")
+
+ except IOError:
+ _LOGGER.exception(f"Error writing to file {file_path}.")
+ else:
+ if verbose:
+ print(f" Saved distribution for '{col_name}'.")
+ counter += 1
+
+ _LOGGER.info(f"{counter} distribution files created.")
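The new `save_category_counts` helper mirrors `save_unique_values` but reports per-column value frequencies and percentages, which is useful for spotting class imbalance before training. A minimal usage sketch (the toy DataFrame and the output folder name below are invented for illustration; the import path follows the package's `ml_tools.ETL_cleaning` exports shown above):

import polars as pl
from ml_tools.ETL_cleaning import save_category_counts

# Toy data: a categorical column with an imbalanced distribution and a null.
df = pl.DataFrame({
    "material": ["Al", "Al", "Fe", "Al", "Cu", None],
    "batch": ["A", "B", "A", "A", "B", "B"],
})

# Writes one "<n>_<column>_counts.txt" report per column; nulls are counted as
# their own category and high-cardinality columns get a warning header.
save_category_counts(
    csv_path_or_df=df,
    output_dir="reports/category_counts",  # hypothetical output folder
    use_columns=["material", "batch"],
    verbose=True,
    keep_column_order=True,
)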
ml_tools/ETL_cleaning/_dragon_cleaner.py CHANGED
@@ -1,13 +1,13 @@
  import polars as pl
  from pathlib import Path
- from typing import Union
+ from typing import Union, Optional

  from ..utilities import save_dataframe_filename, load_dataframe

  from .._core import get_logger
  from ..path_manager import make_fullpath

- from ._clean_tools import save_unique_values
+ from ._clean_tools import save_unique_values, save_category_counts


  _LOGGER = get_logger("DragonCleaner")
@@ -33,12 +33,18 @@ class DragonColumnCleaner:
  """
  def __init__(self,
  column_name: str,
- rules: Union[dict[str, Union[str, None]], dict[str, str]],
+ exact_matches: Optional[Union[dict[str, Union[str, None]], dict[str, str]]] = None,
+ rules: Optional[Union[dict[str, Union[str, None]], dict[str, str]]] = None,
  case_insensitive: bool = False):
  """
  Args:
  column_name (str):
  The name of the column to be cleaned.
+ exact_matches (Dict[str, str | None]):
+ A dictionary of EXACT string matches to replacement strings.
+ - Uses a hash map, which is significantly faster than regex.
+ - Used for simple 1-to-1 mappings (e.g., {'Aluminum': 'Al'}).
+ - Runs BEFORE the regex rules.
  rules (Dict[str, str | None]):
  A dictionary of regex patterns to replacement strings.
  - Replacement can be None to indicate that matching values should be converted to null.
@@ -61,25 +67,47 @@ class DragonColumnCleaner:
  if not isinstance(column_name, str) or not column_name:
  _LOGGER.error("The 'column_name' must be a non-empty string.")
  raise TypeError()
- if not isinstance(rules, dict):
- _LOGGER.error("The 'rules' argument must be a dictionary.")
- raise TypeError()
- # validate rules
- for pattern, replacement in rules.items():
- if not isinstance(pattern, str):
- _LOGGER.error("All keys in 'rules' must be strings representing regex patterns.")
+
+ # Validate Regex Rules
+ if rules is not None:
+ if not isinstance(rules, dict):
+ _LOGGER.error("The 'rules' argument must be a dictionary.")
  raise TypeError()
- if replacement is not None and not isinstance(replacement, str):
- _LOGGER.error("All values in 'rules' must be strings or None (for nullification).")
+ for pattern, replacement in rules.items():
+ if not isinstance(pattern, str):
+ _LOGGER.error("All keys in 'rules' must be strings representing regex patterns.")
+ raise TypeError()
+ if replacement is not None and not isinstance(replacement, str):
+ _LOGGER.error("All values in 'rules' must be strings or None (for nullification).")
+ raise TypeError()
+
+ # Validate Exact Matches
+ if exact_matches is not None:
+ if not isinstance(exact_matches, dict):
+ _LOGGER.error("The 'exact_matches' argument must be a dictionary.")
  raise TypeError()
+ for key, val in exact_matches.items():
+ if not isinstance(key, str):
+ _LOGGER.error("All keys in 'exact_matches' must be strings.")
+ raise TypeError()
+ if val is not None and not isinstance(val, str):
+ _LOGGER.error("All values in 'exact_matches' must be strings or None.")
+ raise TypeError()
+
+ # Raise if both are None or empty
+ if not rules and not exact_matches:
+ _LOGGER.error("At least one of 'rules' or 'exact_matches' must be provided.")
+ raise ValueError()

  self.column_name = column_name
- self.rules = rules
+ self.rules = rules if rules else {}
+ self.exact_matches = exact_matches if exact_matches else {}
  self.case_insensitive = case_insensitive

  def preview(self,
  csv_path: Union[str, Path],
  report_dir: Union[str, Path],
+ show_distribution: bool = True,
  add_value_separator: bool=False,
  rule_batch_size: int = 150):
  """
@@ -90,6 +118,8 @@ class DragonColumnCleaner:
  The path to the CSV file containing the data to clean.
  report_dir (str | Path):
  The directory where the preview report will be saved.
+ show_distribution (bool):
+ If True, generates a category count report for the column after cleaning.
  add_value_separator (bool):
  If True, adds a separator line between each unique value in the report.
  rule_batch_size (int):
@@ -101,13 +131,21 @@ class DragonColumnCleaner:
  preview_cleaner = DragonDataFrameCleaner(cleaners=[self])
  df_preview = preview_cleaner.clean(df, rule_batch_size=rule_batch_size)

- # Apply cleaning rules to a copy of the column for preview
+ # Apply cleaning rules and save reports
  save_unique_values(csv_path_or_df=df_preview,
  output_dir=report_dir,
  use_columns=[self.column_name],
  verbose=False,
  keep_column_order=False,
  add_value_separator=add_value_separator)
+
+ # Optionally save category counts
+ if show_distribution:
+ save_category_counts(csv_path_or_df=df_preview,
+ output_dir=report_dir,
+ use_columns=[self.column_name],
+ verbose=False,
+ keep_column_order=False)


  class DragonDataFrameCleaner:
@@ -181,16 +219,23 @@ class DragonDataFrameCleaner:
  for cleaner in self.cleaners:
  col_name = cleaner.column_name

- # Get all rules as a list of items
+ # Start expression for this batch
+ col_expr = pl.col(col_name).cast(pl.String)
+
+ # --- PHASE 1: EXACT MATCHES ---
+ # Apply dictionary-based replacement first (faster than regex)
+ if cleaner.exact_matches:
+ # 'replace' handles dictionary mapping safely. If value is mapped to None, it becomes null.
+ col_expr = col_expr.replace(cleaner.exact_matches)
+
+ # --- PHASE 2: REGEX PATTERNS ---
  all_rules = list(cleaner.rules.items())

  # Process in batches of 'rule_batch_size'
  for i in range(0, len(all_rules), rule_batch_size):
  rule_batch = all_rules[i : i + rule_batch_size]

- # Start expression for this batch
- col_expr = pl.col(col_name).cast(pl.String)
-
+ # continue chaining operations on the same col_expr
  for pattern, replacement in rule_batch:
  final_pattern = f"(?i){pattern}" if cleaner.case_insensitive else pattern

@@ -202,6 +247,15 @@ class DragonDataFrameCleaner:
  col_expr = col_expr.str.replace_all(final_pattern, replacement)

  # Apply this batch of rules to the LazyFrame
+ # apply partially here to keep the logical plan size under control
+ final_lf = final_lf.with_columns(col_expr.alias(col_name))
+
+ # Reset col_expr for the next batch, but pointing to the 'new' column
+ # This ensures the next batch works on the result of the previous batch
+ col_expr = pl.col(col_name)
+
+ # If we had exact matches but NO regex rules, we still need to apply the expression once
+ if cleaner.exact_matches and not all_rules:
  final_lf = final_lf.with_columns(col_expr.alias(col_name))

  # 3. Collect Results
@@ -242,4 +296,3 @@ class DragonDataFrameCleaner:
  save_dataframe_filename(df=df_clean, save_dir=output_filepath.parent, filename=output_filepath.name)

  return None
-
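To illustrate the new two-phase cleaning (exact hash-map matches first, regex rules second), here is a hedged sketch of how a cleaner might be configured. The column name, mappings, and patterns are invented for illustration, and it assumes `DragonDataFrameCleaner.clean()` accepts a Polars DataFrame, as the `preview()` path above suggests:

import polars as pl
from ml_tools.ETL_cleaning import DragonColumnCleaner, DragonDataFrameCleaner

cleaner = DragonColumnCleaner(
    column_name="material",
    # Phase 1: exact, dictionary-based lookups (no regex cost); None nullifies the value.
    exact_matches={"Aluminum": "Al", "aluminium": "Al", "N/A": None},
    # Phase 2: regex rules, applied afterwards in batches.
    rules={r"\s+": " ", r"^unknown.*$": None},
    case_insensitive=True,
)

df = pl.DataFrame({"material": ["Aluminum", "  iron ", "N/A", "UNKNOWN alloy"]})

pipeline = DragonDataFrameCleaner(cleaners=[cleaner])
df_clean = pipeline.clean(df, rule_batch_size=150)
print(df_clean)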
ml_tools/ML_configuration/_metrics.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Union
+ from typing import Union, Literal


  __all__ = [
@@ -26,7 +26,7 @@ class _BaseClassificationFormat:
  def __init__(self,
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
- calibration_bins: int=15,
+ calibration_bins: Union[int, Literal['auto']]='auto',
  xtick_size: int=22,
  ytick_size: int=22,
  legend_size: int=26,
@@ -46,8 +46,8 @@ class _BaseClassificationFormat:
  - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
  - Hex codes: '#FF6347', '#4682B4'

- calibration_bins (int): The number of bins to use when
- creating the calibration (reliability) plot.
+ calibration_bins (int | 'auto'): The number of bins to use when creating the calibration (reliability) plot. If 'auto', the number will be dynamically determined based on the number of samples.
+ - Typical int values: 10, 15, 20

  font_size (int): The base font size to apply to the plots.

@@ -97,6 +97,7 @@ class _BaseMultiLabelFormat:
  def __init__(self,
  cmap: str = "BuGn",
  ROC_PR_line: str='darkorange',
+ calibration_bins: Union[int, Literal['auto']]='auto',
  font_size: int = 25,
  xtick_size: int=20,
  ytick_size: int=20,
@@ -115,6 +116,9 @@ class _BaseMultiLabelFormat:
  - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
  - Hex codes: '#FF6347', '#4682B4'

+ calibration_bins (int | 'auto'): The number of bins to use when creating the calibration (reliability) plots for each label. If 'auto', the number will be dynamically determined based on the number of samples.
+ - Typical int values: 10, 15, 20
+
  font_size (int): The base font size to apply to the plots.

  xtick_size (int): Font size for x-axis tick labels.
@@ -133,6 +137,7 @@ class _BaseMultiLabelFormat:
  """
  self.cmap = cmap
  self.ROC_PR_line = ROC_PR_line
+ self.calibration_bins = calibration_bins
  self.font_size = font_size
  self.xtick_size = xtick_size
  self.ytick_size = ytick_size
@@ -142,6 +147,7 @@ class _BaseMultiLabelFormat:
  parts = [
  f"cmap='{self.cmap}'",
  f"ROC_PR_line='{self.ROC_PR_line}'",
+ f"calibration_bins={self.calibration_bins}",
  f"font_size={self.font_size}",
  f"xtick_size={self.xtick_size}",
  f"ytick_size={self.ytick_size}",
@@ -416,7 +422,7 @@ class FormatBinaryClassificationMetrics(_BaseClassificationFormat):
  def __init__(self,
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
- calibration_bins: int=15,
+ calibration_bins: Union[int, Literal['auto']]='auto',
  font_size: int=26,
  xtick_size: int=22,
  ytick_size: int=22,
@@ -440,7 +446,7 @@ class FormatMultiClassClassificationMetrics(_BaseClassificationFormat):
  def __init__(self,
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
- calibration_bins: int=15,
+ calibration_bins: Union[int, Literal['auto']]='auto',
  font_size: int=26,
  xtick_size: int=22,
  ytick_size: int=22,
@@ -464,7 +470,7 @@ class FormatBinaryImageClassificationMetrics(_BaseClassificationFormat):
  def __init__(self,
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
- calibration_bins: int=15,
+ calibration_bins: Union[int, Literal['auto']]='auto',
  font_size: int=26,
  xtick_size: int=22,
  ytick_size: int=22,
@@ -488,7 +494,7 @@ class FormatMultiClassImageClassificationMetrics(_BaseClassificationFormat):
  def __init__(self,
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
- calibration_bins: int=15,
+ calibration_bins: Union[int, Literal['auto']]='auto',
  font_size: int=26,
  xtick_size: int=22,
  ytick_size: int=22,
@@ -513,6 +519,7 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
  def __init__(self,
  cmap: str = "BuGn",
  ROC_PR_line: str='darkorange',
+ calibration_bins: Union[int, Literal['auto']]='auto',
  font_size: int = 25,
  xtick_size: int=20,
  ytick_size: int=20,
@@ -520,6 +527,7 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
  ) -> None:
  super().__init__(cmap=cmap,
  ROC_PR_line=ROC_PR_line,
+ calibration_bins=calibration_bins,
  font_size=font_size,
  xtick_size=xtick_size,
  ytick_size=ytick_size,
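With the default switched from 15 bins to 'auto', a configuration object can either pin an explicit bin count or defer to the sample-size heuristic. A hedged sketch; the import path below assumes these classes are re-exported from `ml_tools.ML_configuration` (they are defined in the private `_metrics` module shown above):

from ml_tools.ML_configuration import (  # assumed re-export location
    FormatBinaryClassificationMetrics,
    FormatMultiLabelBinaryClassificationMetrics,
)

# Defer bin selection to the sample-size heuristic ('auto' is now the default).
auto_fmt = FormatBinaryClassificationMetrics(calibration_bins="auto")

# Or pin an explicit bin count, e.g. to compare runs on an equal footing.
fixed_fmt = FormatBinaryClassificationMetrics(calibration_bins=15)

# Multi-label formats now accept (and forward) the same parameter.
multi_fmt = FormatMultiLabelBinaryClassificationMetrics(calibration_bins=10)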
ml_tools/ML_evaluation/_classification.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
  import pandas as pd
  import matplotlib.pyplot as plt
  import seaborn as sns
- from sklearn.calibration import CalibrationDisplay
+ from sklearn.calibration import calibration_curve
  from sklearn.metrics import (
  classification_report,
  ConfusionMatrixDisplay,
@@ -378,42 +378,42 @@ def classification_metrics(save_dir: Union[str, Path],

  # --- Save Calibration Plot ---
  fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
+
+ user_chosen_bins = format_config.calibration_bins
+
+ # --- Automate Bin Selection ---
+ if not isinstance(user_chosen_bins, int) or user_chosen_bins <= 0:
+ # Determine bins based on number of samples
+ n_samples = y_true.shape[0]
+ if n_samples < 200:
+ dynamic_bins = 5
+ elif n_samples < 1000:
+ dynamic_bins = 10
+ else:
+ dynamic_bins = 15
+ else:
+ dynamic_bins = user_chosen_bins
+
+ # --- Step 1: Get binned data directly ---
+ # calculates reliability diagram data without needing a temporary plot
+ prob_true, prob_pred = calibration_curve(y_true_binary, y_score, n_bins=dynamic_bins)

- # --- Step 1: Get binned data *without* plotting ---
- with plt.ioff(): # Suppress showing the temporary plot
- fig_temp, ax_temp = plt.subplots()
- cal_display_temp = CalibrationDisplay.from_predictions(
- y_true_binary, # Use binarized labels
- y_score,
- n_bins=format_config.calibration_bins,
- ax=ax_temp,
- name="temp" # Add a name to suppress potential warnings
- )
- # Get the x, y coordinates of the binned data
- line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
- plt.close(fig_temp) # Close the temporary plot
-
- # --- Step 2: Build the plot from scratch ---
+ # --- Step 2: Plot ---
  ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')

- sns.regplot(
- x=line_x,
- y=line_y,
- ax=ax_cal,
- scatter=False,
- label=f"Model calibration",
- line_kws={
- 'color': format_config.ROC_PR_line,
- 'linestyle': '--',
- 'linewidth': 2,
- }
- )
+ # Plot the actual calibration curve (connect points with a line)
+ ax_cal.plot(prob_pred,
+ prob_true,
+ marker='o', # Add markers to see bin locations
+ linewidth=2,
+ label="Model calibration",
+ color=format_config.ROC_PR_line)

  ax_cal.set_title(f'Reliability Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
  ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
  ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)

- # --- Step 3: Set final limits *after* plotting ---
+ # --- Step 3: Set final limits ---
  ax_cal.set_ylim(0.0, 1.0)
  ax_cal.set_xlim(0.0, 1.0)

@@ -428,7 +428,7 @@ def classification_metrics(save_dir: Union[str, Path],
  cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
  plt.savefig(cal_path)
  plt.close(fig_cal)
-
+
  _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")


@@ -632,6 +632,52 @@ def multi_label_classification_metrics(
  pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
  plt.savefig(pr_path)
  plt.close(fig_pr)
+
+ # --- Save Calibration Plot (New Feature) ---
+ fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
+
+ user_chosen_bins = format_config.calibration_bins
+
+ # --- Automate Bin Selection ---
+ if not isinstance(user_chosen_bins, int) or user_chosen_bins <= 0:
+ # Determine bins based on number of samples
+ n_samples = y_true.shape[0]
+ if n_samples < 200:
+ dynamic_bins = 5
+ elif n_samples < 1000:
+ dynamic_bins = 10
+ else:
+ dynamic_bins = 15
+ else:
+ dynamic_bins = user_chosen_bins
+
+ # Calculate calibration curve for this specific label
+ prob_true, prob_pred = calibration_curve(true_i, prob_i, n_bins=dynamic_bins)
+
+ ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+ ax_cal.plot(prob_pred,
+ prob_true,
+ marker='o',
+ linewidth=2,
+ label=f"Calibration for '{name}'",
+ color=format_config.ROC_PR_line)
+
+ ax_cal.set_title(f'Reliability Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+ ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+
+ ax_cal.set_ylim(0.0, 1.0)
+ ax_cal.set_xlim(0.0, 1.0)
+
+ ax_cal.tick_params(axis='x', labelsize=xtick_size)
+ ax_cal.tick_params(axis='y', labelsize=ytick_size)
+ ax_cal.legend(loc='lower right', fontsize=legend_size)
+ ax_cal.grid(True)
+
+ plt.tight_layout()
+ cal_path = save_dir_path / f"calibration_plot_{sanitized_name}.svg"
+ plt.savefig(cal_path)
+ plt.close(fig_cal)

  _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")

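The swap from `CalibrationDisplay.from_predictions` to `sklearn.calibration.calibration_curve`, combined with the sample-size bin heuristic, can be reproduced standalone. A minimal sketch with synthetic data (this is not the package's plotting code; the 200/1000 thresholds and 5/10/15 bin counts are copied from the hunk above, everything else is illustrative):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve

rng = np.random.default_rng(0)
n_samples = 500
y_true = rng.integers(0, 2, size=n_samples)
# Noisy scores loosely correlated with the labels, clipped into [0, 1].
y_score = np.clip(y_true * 0.6 + rng.normal(0.2, 0.25, size=n_samples), 0.0, 1.0)

# Same heuristic as above: smaller samples get fewer, wider bins.
if n_samples < 200:
    n_bins = 5
elif n_samples < 1000:
    n_bins = 10
else:
    n_bins = 15

# Reliability-diagram data, no temporary figure needed.
prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=n_bins)

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], "k--", label="Perfectly calibrated")
ax.plot(prob_pred, prob_true, marker="o", linewidth=2, label="Model calibration")
ax.set_xlabel("Mean Predicted Probability")
ax.set_ylabel("Fraction of Positives")
ax.legend(loc="lower right")
fig.savefig("calibration_plot_example.svg")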
ml_tools/ML_optimization/_multi_dragon.py CHANGED
@@ -170,9 +170,13 @@ class DragonParetoOptimizer:
  re_evaluate=False # model is deterministic
  )

- def run(self) -> pd.DataFrame:
+ def run(self,
+ plots_and_log: bool=True) -> pd.DataFrame:
  """
  Execute the optimization with progress tracking and periodic logging.
+
+ Args:
+ plots_and_log (bool): If True, generates plots and logs during optimization. Disable for multi-run scenarios.

  Returns:
  pd.DataFrame: A DataFrame containing the non-dominated solutions (Pareto Front).
@@ -189,9 +193,10 @@ class DragonParetoOptimizer:
  _LOGGER.info(f"🧬 Starting NSGA-II (GeneticAlgorithm) for {generations} generations...")

  # Initialize log file
- with open(log_file, "w") as f:
- f.write(f"Pareto Optimization Log - {generations} Generations\n")
- f.write("=" * 60 + "\n")
+ if plots_and_log:
+ with open(log_file, "w") as f:
+ f.write(f"Pareto Optimization Log - {generations} Generations\n")
+ f.write("=" * 60 + "\n")

  # History tracking for visualization
  history_records = []
@@ -201,43 +206,44 @@ class DragonParetoOptimizer:
  for gen in range(1, generations + 1):
  self.algorithm.step()

- # Capture stats for history (every generation for smooth plots)
- current_evals = self.algorithm.population.evals.clone() # type: ignore
-
- gen_stats = {}
- for i, target_name in enumerate(self.ordered_target_names):
- vals = current_evals[:, i]
- v_mean = float(vals.mean())
- v_min = float(vals.min())
- v_max = float(vals.max())
-
- # Store for plotting
- history_records.append({
- "Generation": gen,
- "Target": target_name,
- "Mean": v_mean,
- "Min": v_min,
- "Max": v_max
- })
-
- gen_stats[target_name] = (v_mean, v_min, v_max)
-
- # Periodic Logging of Population Stats to FILE
- if gen % log_interval == 0 or gen == generations:
- stats_msg = [f"Gen {gen}:"]
- for t_name, (v_mean, v_min, v_max) in gen_stats.items():
- stats_msg.append(f"{t_name}: {v_mean:.3f} (Range: {v_min:.3f}-{v_max:.3f})")
+ if plots_and_log:
+ # Capture stats for history (every generation for smooth plots)
+ current_evals = self.algorithm.population.evals.clone() # type: ignore

- log_line = " | ".join(stats_msg)
+ gen_stats = {}
+ for i, target_name in enumerate(self.ordered_target_names):
+ vals = current_evals[:, i]
+ v_mean = float(vals.mean())
+ v_min = float(vals.min())
+ v_max = float(vals.max())
+
+ # Store for plotting
+ history_records.append({
+ "Generation": gen,
+ "Target": target_name,
+ "Mean": v_mean,
+ "Min": v_min,
+ "Max": v_max
+ })
+
+ gen_stats[target_name] = (v_mean, v_min, v_max)

- # Write to file
- with open(log_file, "a") as f:
- f.write(log_line + "\n")
+ # Periodic Logging of Population Stats to FILE
+ if gen % log_interval == 0 or gen == generations:
+ stats_msg = [f"Gen {gen}:"]
+ for t_name, (v_mean, v_min, v_max) in gen_stats.items():
+ stats_msg.append(f"{t_name}: {v_mean:.3f} (Range: {v_min:.3f}-{v_max:.3f})")
+
+ log_line = " | ".join(stats_msg)
+
+ # Write to file
+ with open(log_file, "a") as f:
+ f.write(log_line + "\n")

  pbar.update(1)

  # --- Post-Optimization Visualization ---
- if history_records:
+ if plots_and_log and history_records:
  _LOGGER.debug("Generating optimization history plots...")
  history_df = pd.DataFrame(history_records)
  self._plot_optimization_history(history_df, save_path)
@@ -308,7 +314,8 @@ class DragonParetoOptimizer:
  _LOGGER.info(f"Optimization complete. Found {len(pareto_df)} non-dominated solutions.")

  # --- Plotting ---
- self._generate_plots(pareto_df, save_path)
+ if plots_and_log:
+ self._generate_plots(pareto_df, save_path)

  return pareto_df

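The new `plots_and_log` flag exists so that repeated optimizations do not each write a history log and a full set of plots. A hedged sketch of a multi-run wrapper; it relies only on `run()` returning a DataFrame of non-dominated solutions, and takes a user-supplied factory because the `DragonParetoOptimizer` constructor is not part of this diff:

import pandas as pd

def run_repeated(make_optimizer, n_runs: int = 5) -> pd.DataFrame:
    """Run several independent optimizations, silencing per-run plots and logs.

    `make_optimizer` is a zero-argument callable supplied by the caller that
    returns a freshly configured DragonParetoOptimizer (construction not shown here).
    """
    fronts = []
    for run_idx in range(n_runs):
        optimizer = make_optimizer()
        # plots_and_log=False skips the history log file and all plotting for this run.
        pareto_df = optimizer.run(plots_and_log=False).assign(run=run_idx)
        fronts.append(pareto_df)
    return pd.concat(fronts, ignore_index=True)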
ml_tools/data_exploration/__init__.py CHANGED
@@ -36,6 +36,7 @@ from ._features import (
  from ._schema_ops import (
  finalize_feature_schema,
  apply_feature_schema,
+ reconstruct_from_schema
  )

  from .._core import _imprimir_disponibles
@@ -62,6 +63,7 @@ __all__ = [
  "encode_categorical_features",
  "finalize_feature_schema",
  "apply_feature_schema",
+ "reconstruct_from_schema",
  "match_and_filter_columns_by_regex",
  "standardize_percentages",
  "reconstruct_one_hot",
ml_tools/data_exploration/_schema_ops.py CHANGED
@@ -9,6 +9,13 @@ from .._core import get_logger
  _LOGGER = get_logger("Data Exploration: Schema Ops")


+ __all__ = [
+ "finalize_feature_schema",
+ "apply_feature_schema",
+ "reconstruct_from_schema",
+ ]
+
+
  def finalize_feature_schema(
  df_features: pd.DataFrame,
  categorical_mappings: Optional[dict[str, dict[str, int]]]
@@ -86,7 +93,7 @@ def apply_feature_schema(
  schema: FeatureSchema,
  targets: Optional[list[str]] = None,
  unknown_value: int = 99999,
- verbose: bool = True
+ verbose: int = 3
  ) -> pd.DataFrame:
  """
  Aligns the input DataFrame with the provided FeatureSchema.
@@ -100,7 +107,7 @@ def apply_feature_schema(
  targets (list[str] | None): Optional list of target column names.
  unknown_value (int): Integer value to assign to unknown categorical levels.
  Defaults to 99999 to avoid collision with existing categories.
- verbose (bool): If True, logs info about dropped extra columns.
+ verbose (int): Verbosity level for logging. Higher values produce more detailed logs.

  Returns:
  pd.DataFrame: A new DataFrame with the exact column order and encoding defined by the schema.
@@ -147,7 +154,8 @@ def apply_feature_schema(
  # Handle Unknown Categories
  if df_processed[col_name].isnull().any():
  n_missing = df_processed[col_name].isnull().sum()
- _LOGGER.warning(f"Feature '{col_name}': Found {n_missing} unknown categories. Mapping to {unknown_value}.")
+ if verbose >= 1:
+ _LOGGER.warning(f"Feature '{col_name}': Found {n_missing} unknown categories. Mapping to {unknown_value}.")

  # Fill unknowns with the specified integer
  df_processed[col_name] = df_processed[col_name].fillna(unknown_value)
@@ -159,14 +167,13 @@ def apply_feature_schema(

  extra_cols = set(df_processed.columns) - set(final_column_order)
  if extra_cols:
- _LOGGER.info(f"Dropping {len(extra_cols)} extra columns not present in schema.")
- if verbose:
- for extra_column in extra_cols:
- print(f" - Dropping column: '{extra_column}'")
+ if verbose >= 1:
+ _LOGGER.warning(f"Dropping {len(extra_cols)} extra columns not present in schema: {extra_cols}")

  df_final = df_processed[final_column_order]

- _LOGGER.info(f"Schema applied successfully. Final shape: {df_final.shape}")
+ if verbose >= 2:
+ _LOGGER.info(f"Schema applied successfully. Final shape: {df_final.shape}")

  # df_final should be a dataframe
  if isinstance(df_final, pd.Series):
@@ -174,3 +181,95 @@ def apply_feature_schema(

  return df_final

+
+
+ def reconstruct_from_schema(
+ df: pd.DataFrame,
+ schema: FeatureSchema,
+ targets: Optional[list[str]] = None,
+ verbose: int = 3
+ ) -> pd.DataFrame:
+ """
+ Reverses the schema application to make data human-readable.
+
+ This function decodes categorical features back to their string representations
+ using the schema's mappings. It strictly enforces the schema structure,
+ ignoring extra columns (unless they are specified as targets).
+
+ Args:
+ df (pd.DataFrame): The input DataFrame containing encoded features.
+ schema (FeatureSchema): The schema defining feature names and reverse mappings.
+ targets (list[str] | None): Optional list of target column names to preserve. These are not decoded and kept in the order specified here.
+ verbose (int): Verbosity level for logging info about the process.
+
+ Returns:
+ pd.DataFrame: A new DataFrame with the exact column order (features + targets),
+ with categorical features decoded to strings.
+
+ Raises:
+ ValueError: If any required feature or target column is missing.
+ """
+ # 1. Setup
+ df_decoded = df.copy()
+ targets = targets if targets is not None else []
+
+ # 2. Validation: Strict Column Presence
+ # Check Features
+ missing_features = [col for col in schema.feature_names if col not in df_decoded.columns]
+ if missing_features:
+ _LOGGER.error(f"Schema Reconstruction Mismatch: Missing required features: {missing_features}")
+ raise ValueError()
+
+ # Check Targets
+ if targets:
+ missing_targets = [col for col in targets if col not in df_decoded.columns]
+ if missing_targets:
+ _LOGGER.error(f"Schema Reconstruction Mismatch: Missing required targets: {missing_targets}")
+ raise ValueError()
+
+ # 3. Reorder and Filter (Drop extra columns early)
+ # The valid columns are Features + Targets
+ valid_columns = list(schema.feature_names) + targets
+
+ extra_cols = set(df_decoded.columns) - set(valid_columns)
+ if extra_cols:
+ if verbose >= 1:
+ _LOGGER.warning(f"Dropping extra columns not present in schema or targets: {extra_cols}")
+
+ # Enforce order: Features first, then Targets
+ df_decoded = df_decoded[valid_columns]
+
+ # 4. Reverse Categorical Encoding
+ if schema.categorical_feature_names and schema.categorical_mappings:
+ for col_name in schema.categorical_feature_names:
+ if col_name not in schema.categorical_mappings:
+ continue
+
+ forward_mapping = schema.categorical_mappings[col_name]
+ # Create reverse map: {int: str}
+ reverse_mapping = {v: k for k, v in forward_mapping.items()}
+
+ # --- SAFE TYPE CASTING ---
+ # Ensure values are Integers before mapping (handle 5.0 vs 5).
+ try:
+ if pd.api.types.is_numeric_dtype(df_decoded[col_name]):
+ df_decoded[col_name] = df_decoded[col_name].astype("Int64")
+ except (TypeError, ValueError):
+ # casted to NaN later during mapping
+ pass
+ # -------------------------
+
+ # Check for unknown codes before mapping
+ if verbose >= 1:
+ unique_codes = df_decoded[col_name].dropna().unique()
+ unknown_codes = [code for code in unique_codes if code not in reverse_mapping]
+ if unknown_codes:
+ _LOGGER.warning(f"Feature '{col_name}': Found unknown encoded values {unknown_codes}. These will be mapped to NaN.")
+
+ # Apply reverse mapping
+ df_decoded[col_name] = df_decoded[col_name].map(reverse_mapping)
+
+ if verbose >= 2:
+ _LOGGER.info(f"Schema reconstruction successful. Final shape: {df_decoded.shape}")
+
+ return df_decoded
ml_tools/keys/_keys.py CHANGED
@@ -4,6 +4,7 @@ class MagicWords:
  CURRENT = "current"
  RENAME = "rename"
  UNKNOWN = "unknown"
+ AUTO = "auto"


  class PyTorchLogKeys:
ml_tools/schema/_feature_schema.py CHANGED
@@ -202,13 +202,39 @@ class FeatureSchema(NamedTuple):
  filename=DatasetKeys.CATEGORICAL_NAMES,
  verbose=verbose)

- def save_artifacts(self, directory: Union[str,Path]):
+ def save_description(self, directory: Union[str, Path], verbose: bool = False) -> None:
+ """
+ Saves the schema's description to a .txt file.
+
+ Args:
+ directory: The directory where the file will be saved.
+ verbose: If True, prints a confirmation message upon saving.
+ """
+ dir_path = make_fullpath(directory, make=True, enforce="directory")
+ filename = "FeatureSchema-description.txt"
+ file_path = dir_path / filename
+
+ try:
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.write(str(self))
+
+ if verbose:
+ _LOGGER.info(f"Schema description saved to '{dir_path.name}/{filename}'")
+ except IOError as e:
+ _LOGGER.error(f"Failed to save schema description: {e}")
+ raise e
+
+ def save_artifacts(self, directory: Union[str,Path], verbose: bool=True):
  """
  Saves feature names, categorical feature names, continuous feature names to separate text files.
  """
- self.save_all_features(directory=directory, verbose=True)
- self.save_continuous_features(directory=directory, verbose=True)
- self.save_categorical_features(directory=directory, verbose=True)
+ self.save_all_features(directory=directory, verbose=False)
+ self.save_continuous_features(directory=directory, verbose=False)
+ self.save_categorical_features(directory=directory, verbose=False)
+ self.save_description(directory=directory, verbose=False)
+
+ if verbose:
+ _LOGGER.info(f"All FeatureSchema artifacts saved to directory: '{directory}'")

  def __repr__(self) -> str:
  """Returns a concise representation of the schema's contents."""
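`save_artifacts` now also writes the human-readable description file and emits one summary log line instead of one message per artifact. A hedged usage sketch, assuming you already hold a `FeatureSchema` instance (how it is built is not part of this diff):

from pathlib import Path

def export_schema(schema, out_dir: str = "schema_artifacts") -> Path:
    """Write all FeatureSchema text artifacts, including the new description file."""
    out_path = Path(out_dir)
    # One call now covers the feature-name lists plus 'FeatureSchema-description.txt';
    # verbose=True logs a single confirmation for the whole directory.
    schema.save_artifacts(directory=out_path, verbose=True)
    return out_path

# The description alone can also be written:
# schema.save_description("schema_artifacts", verbose=True)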