dragon-ml-toolbox 13.6.0__py3-none-any.whl → 13.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-13.6.0.dist-info → dragon_ml_toolbox-13.8.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-13.6.0.dist-info → dragon_ml_toolbox-13.8.0.dist-info}/RECORD +11 -11
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_utilities.py +253 -4
- ml_tools/custom_logger.py +26 -8
- ml_tools/keys.py +8 -0
- ml_tools/utilities.py +178 -0
- {dragon_ml_toolbox-13.6.0.dist-info → dragon_ml_toolbox-13.8.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.6.0.dist-info → dragon_ml_toolbox-13.8.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.6.0.dist-info → dragon_ml_toolbox-13.8.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-13.6.0.dist-info → dragon_ml_toolbox-13.8.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-13.6.0.dist-info → dragon_ml_toolbox-13.8.0.dist-info}/RECORD
CHANGED

@@ -1,9 +1,9 @@
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
+dragon_ml_toolbox-13.8.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-13.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
 ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
 ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
 ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
-ml_tools/MICE_imputation.py,sha256=
+ml_tools/MICE_imputation.py,sha256=KLJXGQLKJ6AuWWttAG-LCCaxpS-ygM4dXPiguHDaL6Y,20815
 ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
 ml_tools/ML_datasetmaster.py,sha256=6caWbq6eu1RE9V51gmceD71PtMctJRjFuLvkkK5ChiY,36271
 ml_tools/ML_evaluation.py,sha256=li77AuP53pCzgrj6p-jTCNtPFgS9Y9XnMWIZn1ulTBM,18946
@@ -13,7 +13,7 @@ ml_tools/ML_models.py,sha256=UVWJHPLVIvFno_csCHH1FwBfTwQ5nX0V8F1TbOByZ4I,31388
 ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
 ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
 ml_tools/ML_trainer.py,sha256=ZxeOagXW5adFhYIH-oMTlcrLU6VHe4R1EROI7yypNwQ,29665
-ml_tools/ML_utilities.py,sha256=
+ml_tools/ML_utilities.py,sha256=QC44y5mAzA6iUdb3py0bjI-nPjxUatZTdm8sMrb3He0,19364
 ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
 ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
 ml_tools/SQL.py,sha256=vXLPGfVVg8bfkbBE3HVfyEclVbdJy0TBhuQONtMwSCQ,11234
@@ -23,19 +23,19 @@ ml_tools/_logger.py,sha256=dlp5cGbzooK9YSNSZYB4yjZrOaQUGW8PTrM411AOvL8,4717
 ml_tools/_schema.py,sha256=yu6aWmn_2Z4_AxAtJGDDCIa96y6JcUp-vgnCS013Qmw,3908
 ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
 ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
-ml_tools/custom_logger.py,sha256=
+ml_tools/custom_logger.py,sha256=i0cAr1qPnwXDyqQ1itk2o72-2jniRXJNEuST2eW4zF4,11016
 ml_tools/data_exploration.py,sha256=-BbWO7BBFapPi_7ZuWo65VqguJXaBfgFSptrXyoWrDk,51902
 ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
 ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
 ml_tools/ensemble_learning.py,sha256=vsIED7nlheYI4w2SBzP6SC1AnNeMfn-2A1Gqw5EfxsM,21964
 ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
-ml_tools/keys.py,sha256=
+ml_tools/keys.py,sha256=CcqE9R9R32osR0vLz0i-3cyv1UlVsDWAHqvlVf8xm_0,2492
 ml_tools/math_utilities.py,sha256=xeKq1quR_3DYLgowcp4Uam_4s3JltUyOnqMOGuAiYWU,8802
 ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJRs,12763
 ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
 ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
-ml_tools/utilities.py,sha256=
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
-dragon_ml_toolbox-13.
+ml_tools/utilities.py,sha256=aWqvYzmxlD74PD5Yqu1VuTekDJeYLQrmPIU_VeVyRp0,22526
+dragon_ml_toolbox-13.8.0.dist-info/METADATA,sha256=mvK0WY75d25CARpUbiDoaK3PHtVgRIEcCauCo7RT6wU,6166
+dragon_ml_toolbox-13.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-13.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-13.8.0.dist-info/RECORD,,
ml_tools/MICE_imputation.py
CHANGED
@@ -7,19 +7,20 @@ from plotnine import ggplot, labs, theme, element_blank # type: ignore
 from typing import Optional, Union

 from .utilities import load_dataframe, merge_dataframes, save_dataframe_filename
-from .math_utilities import threshold_binary_values
+from .math_utilities import threshold_binary_values, discretize_categorical_values
 from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from ._schema import FeatureSchema


 __all__ = [
+    "MiceImputer",
     "apply_mice",
     "save_imputed_datasets",
-    "get_na_column_names",
     "get_convergence_diagnostic",
     "get_imputed_distributions",
-    "run_mice_pipeline"
+    "run_mice_pipeline",
 ]


@@ -79,7 +80,7 @@ def save_imputed_datasets(save_dir: Union[str, Path], imputed_datasets: list, df


 #Get names of features that had missing values before imputation
-def
+def _get_na_column_names(df: pd.DataFrame):
     return [col for col in df.columns if df[col].isna().any()]


@@ -264,7 +265,7 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]

     save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)

-    imputed_column_names =
+    imputed_column_names = _get_na_column_names(df=df)

     get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)

@@ -278,5 +279,206 @@ def _skip_targets(df: pd.DataFrame, target_cols: list[str]):
     return df_feats, df_targets


+# modern implementation
+class MiceImputer:
+    """
+    A modern MICE imputation pipeline that uses a FeatureSchema
+    to correctly discretize categorical features after imputation.
+    """
+    def __init__(self,
+                 schema: FeatureSchema,
+                 iterations: int=20,
+                 resulting_datasets: int = 1,
+                 random_state: int = 101):
+
+        self.schema = schema
+        self.random_state = random_state
+        self.iterations = iterations
+        self.resulting_datasets = resulting_datasets
+
+        # --- Store schema info ---
+
+        # 1. Categorical info
+        if not self.schema.categorical_index_map:
+            _LOGGER.warning("FeatureSchema has no 'categorical_index_map'. No discretization will be applied.")
+            self.cat_info = {}
+        else:
+            self.cat_info = self.schema.categorical_index_map
+
+        # 2. Ordered feature names (critical for index mapping)
+        self.ordered_features = list(self.schema.feature_names)
+
+        # 3. Names of categorical features
+        self.categorical_features = list(self.schema.categorical_feature_names)
+
+        _LOGGER.info(f"MiceImputer initialized. Found {len(self.cat_info)} categorical features to discretize.")
+
+    def _post_process(self, imputed_df: pd.DataFrame) -> pd.DataFrame:
+        """
+        Applies schema-based discretization to a completed dataframe.
+
+        This method works around the behavior of `discretize_categorical_values`
+        (which returns a full int32 array) by:
+        1. Calling it on the full, ordered feature array.
+        2. Extracting *only* the valid discretized categorical columns.
+        3. Updating the original float dataframe with these integer values.
+        """
+        # If no categorical features are defined, return the df as-is.
+        if not self.cat_info:
+            return imputed_df
+
+        try:
+            # 1. Ensure DataFrame columns match the schema order
+            # This is critical for the index-based categorical_info
+            df_ordered: pd.DataFrame = imputed_df[self.ordered_features] # type: ignore
+
+            # 2. Convert to NumPy array
+            array_ordered = df_ordered.to_numpy()
+
+            # 3. Apply discretization utility (which returns a full int32 array)
+            # This array has *correct* categorical values but *truncated* continuous values.
+            discretized_array_int32 = discretize_categorical_values(
+                array_ordered,
+                self.cat_info,
+                start_at_zero=True # Assuming 0-based indexing
+            )
+
+            # 4. Create a new DF from the int32 array, keeping the categorical columns.
+            df_discretized_cats = pd.DataFrame(
+                discretized_array_int32,
+                columns=self.ordered_features,
+                index=df_ordered.index # <-- Critical: align index
+            )[self.categorical_features] # <-- Select only cat features
+
+            # 5. "Rejoin": Start with a fresh copy of the *original* imputed DF (which has correct continuous floats).
+            final_df = df_ordered.copy()
+
+            # 6. Use .update() to "paste" the integer categorical values
+            # over the old float categorical values. Continuous floats are unaffected.
+            final_df.update(df_discretized_cats)
+
+            return final_df
+
+        except Exception as e:
+            _LOGGER.error(f"Failed during post-processing discretization:\n\tInput DF shape: {imputed_df.shape}\n\tSchema features: {len(self.ordered_features)}\n\tCategorical info keys: {list(self.cat_info.keys())}\n{e}")
+            raise
+
+    def _run_mice(self,
+                  df: pd.DataFrame,
+                  df_name: str) -> tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+        """
+        Runs the MICE kernel and applies schema-based post-processing.
+
+        Parameters:
+            df (pd.DataFrame): The input dataframe *with NaNs*. Should only contain feature columns.
+            df_name (str): The base name for the dataset.
+
+        Returns:
+            tuple[mf.ImputationKernel, list[pd.DataFrame], list[str]]:
+                - The trained MICE kernel
+                - A list of imputed and processed DataFrames
+                - A list of names for the new DataFrames
+        """
+        # Ensure input df only contains features from the schema and is in the correct order.
+        try:
+            df_feats = df[self.ordered_features]
+        except KeyError as e:
+            _LOGGER.error(f"Input DataFrame is missing required schema columns: {e}")
+            raise
+
+        # 1. Initialize kernel
+        kernel = mf.ImputationKernel(
+            data=df_feats,
+            num_datasets=self.resulting_datasets,
+            random_state=self.random_state
+        )
+
+        _LOGGER.info("➡️ Schema-based MICE imputation running...")
+
+        # 2. Perform MICE
+        kernel.mice(self.iterations)
+
+        # 3. Retrieve, process, and collect datasets
+        imputed_datasets = []
+        for i in range(self.resulting_datasets):
+            # complete_data returns a pd.DataFrame
+            completed_df = kernel.complete_data(dataset=i)
+
+            # Apply our new discretization and ordering
+            processed_df = self._post_process(completed_df)
+            imputed_datasets.append(processed_df)
+
+        if not imputed_datasets:
+            _LOGGER.error("No imputed datasets were generated.")
+            raise ValueError()
+
+        # 4. Generate names
+        if self.resulting_datasets == 1:
+            imputed_dataset_names = [f"{df_name}_MICE"]
+        else:
+            imputed_dataset_names = [f"{df_name}_MICE_{i+1}" for i in range(self.resulting_datasets)]
+
+        # 5. Validate indexes
+        for imputed_df, subname in zip(imputed_datasets, imputed_dataset_names):
+            assert imputed_df.shape[0] == df.shape[0], f"❌ Row count mismatch in dataset {subname}"
+            assert all(imputed_df.index == df.index), f"❌ Index mismatch in dataset {subname}"
+
+        _LOGGER.info("Schema-based MICE imputation complete.")
+
+        return kernel, imputed_datasets, imputed_dataset_names
+
+    def run_pipeline(self,
+                     df_path_or_dir: Union[str,Path],
+                     save_datasets_dir: Union[str,Path],
+                     save_metrics_dir: Union[str,Path],
+                     ):
+        """
+        Runs the complete MICE imputation pipeline.
+
+        This method automates the entire workflow:
+        1. Loads data from a CSV file path or a directory with CSV files.
+        2. Separates features and targets based on the `FeatureSchema`.
+        3. Runs the MICE algorithm on the feature set.
+        4. Applies schema-based post-processing to discretize categorical features.
+        5. Saves the final, processed, and imputed dataset(s) (re-joined with targets) to `save_datasets_dir`.
+        6. Generates and saves convergence and distribution plots for all imputed columns to `save_metrics_dir`.
+
+        Parameters
+        ----------
+        df_path_or_dir : [str,Path]
+            Path to a single CSV file or a directory containing multiple CSV files to impute.
+        save_datasets_dir : [str,Path]
+            Directory where the final imputed and processed dataset(s) will be saved as CSVs.
+        save_metrics_dir : [str,Path]
+            Directory where convergence and distribution plots will be saved.
+        """
+        # Check paths
+        save_datasets_path = make_fullpath(save_datasets_dir, make=True)
+        save_metrics_path = make_fullpath(save_metrics_dir, make=True)
+
+        input_path = make_fullpath(df_path_or_dir)
+        if input_path.is_file():
+            all_file_paths = [input_path]
+        else:
+            all_file_paths = list(list_csv_paths(input_path).values())
+
+        for df_path in all_file_paths:
+
+            df, df_name = load_dataframe(df_path=df_path, kind="pandas")
+
+            df_features: pd.DataFrame = df[self.schema.feature_names] # type: ignore
+            df_targets = df.drop(columns=self.schema.feature_names)
+
+            imputed_column_names = _get_na_column_names(df=df_features)
+
+            kernel, imputed_datasets, imputed_dataset_names = self._run_mice(df=df_features, df_name=df_name)
+
+            save_imputed_datasets(save_dir=save_datasets_path, imputed_datasets=imputed_datasets, df_targets=df_targets, imputed_dataset_names=imputed_dataset_names)
+
+            get_convergence_diagnostic(kernel=kernel, imputed_dataset_names=imputed_dataset_names, column_names=imputed_column_names, root_dir=save_metrics_path)
+
+            get_imputed_distributions(kernel=kernel, df_name=df_name, root_dir=save_metrics_path, column_names=imputed_column_names)
+
+
 def info():
     _script_info(__all__)
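For orientation, a minimal usage sketch of the new `MiceImputer` class added above. Only the `MiceImputer` API itself is taken from this diff; the `FeatureSchema` constructor arguments, the categorical map values, and the file paths are illustrative assumptions.

```python
# Hypothetical usage of the new schema-aware MICE pipeline (paths and schema values are assumptions).
from ml_tools._schema import FeatureSchema          # import location inferred from this diff
from ml_tools.MICE_imputation import MiceImputer

# Assumed FeatureSchema construction: the diff only shows that the object exposes
# feature_names, categorical_feature_names, and categorical_index_map.
schema = FeatureSchema(
    feature_names=["age", "income", "color"],
    categorical_feature_names=["color"],
    categorical_index_map={2: 3},  # feature index -> category count (assumed semantics)
)

imputer = MiceImputer(schema=schema, iterations=20, resulting_datasets=1, random_state=101)

# Imputes every CSV found at the input path, then writes the processed datasets and
# the convergence/distribution plots to the given directories.
imputer.run_pipeline(
    df_path_or_dir="data/raw",
    save_datasets_dir="data/imputed",
    save_metrics_dir="reports/mice",
)
```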
ml_tools/ML_utilities.py
CHANGED
@@ -1,18 +1,24 @@
 import pandas as pd
 from pathlib import Path
-from typing import Union, Any, Optional
+from typing import Union, Any, Optional, Dict, List, Iterable
+import torch
+from torch import nn
+

 from .path_manager import make_fullpath, list_subdirectories, list_files_by_extension
 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys
+from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys, UtilityKeys, PyTorchCheckpointKeys
 from .utilities import load_dataframe
-from .custom_logger import save_list_strings
+from .custom_logger import save_list_strings, custom_logger


 __all__ = [
     "find_model_artifacts",
-    "select_features_by_shap"
+    "select_features_by_shap",
+    "get_model_parameters",
+    "inspect_pth_file",
+    "set_parameter_requires_grad"
 ]


@@ -226,5 +232,248 @@ def select_features_by_shap(
     return final_features


+def get_model_parameters(model: nn.Module, save_dir: Optional[Union[str,Path]]=None) -> Dict[str, int]:
+    """
+    Calculates the total and trainable parameters of a PyTorch model.
+
+    Args:
+        model (nn.Module): The PyTorch model to inspect.
+        save_dir: Optional directory to save the output as a JSON file.
+
+    Returns:
+        Dict[str, int]: A dictionary containing:
+            - "total_params": The total number of parameters.
+            - "trainable_params": The number of trainable parameters (where requires_grad=True).
+    """
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    report = {
+        UtilityKeys.TOTAL_PARAMS: total_params,
+        UtilityKeys.TRAINABLE_PARAMS: trainable_params
+    }
+
+    if save_dir is not None:
+        output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+        custom_logger(data=report,
+                      save_directory=output_dir,
+                      log_name=UtilityKeys.MODEL_PARAMS_FILE,
+                      dict_as="json")
+
+    return report
+
+
+def inspect_pth_file(
+    pth_path: Union[str, Path],
+    save_dir: Union[str, Path],
+) -> None:
+    """
+    Inspects a .pth file (e.g., checkpoint) and saves a human-readable
+    JSON summary of its contents.
+
+    Args:
+        pth_path (str | Path): The path to the .pth file to inspect.
+        save_dir (str | Path): The directory to save the JSON report.
+
+    Returns:
+        None. The inspection report is written to `save_dir` as a JSON file.
+
+    Raises:
+        ValueError: If the .pth file is empty or in an unrecognized format.
+    """
+    # --- 1. Validate paths ---
+    pth_file = make_fullpath(pth_path, enforce="file")
+    output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+    pth_name = pth_file.stem
+
+    # --- 2. Load data ---
+    try:
+        # Load onto CPU to avoid GPU memory issues
+        loaded_data = torch.load(pth_file, map_location=torch.device('cpu'))
+    except Exception as e:
+        _LOGGER.error(f"Failed to load .pth file '{pth_file}': {e}")
+        raise
+
+    # --- 3. Initialize Report ---
+    report = {
+        "top_level_type": str(type(loaded_data)),
+        "top_level_summary": {},
+        "model_state_analysis": None,
+        "notes": []
+    }
+
+    # --- 4. Parse loaded data ---
+    if isinstance(loaded_data, dict):
+        # --- Case 1: Loaded data is a dictionary (most common case) ---
+        # "main loop" that iterates over *everything* first.
+        for key, value in loaded_data.items():
+            key_summary = {}
+            val_type = str(type(value))
+            key_summary["type"] = val_type
+
+            if isinstance(value, torch.Tensor):
+                key_summary["shape"] = list(value.shape)
+                key_summary["dtype"] = str(value.dtype)
+            elif isinstance(value, dict):
+                key_summary["key_count"] = len(value)
+                key_summary["key_preview"] = list(value.keys())[:5]
+            elif isinstance(value, (int, float, str, bool)):
+                key_summary["value_preview"] = str(value)
+            elif isinstance(value, (list, tuple)):
+                key_summary["value_preview"] = str(value)[:100]
+
+            report["top_level_summary"][key] = key_summary
+
+        # Now, try to find the model state_dict within the dict
+        if PyTorchCheckpointKeys.MODEL_STATE in loaded_data and isinstance(loaded_data[PyTorchCheckpointKeys.MODEL_STATE], dict):
+            report["notes"].append(f"Found standard checkpoint key: '{PyTorchCheckpointKeys.MODEL_STATE}'. Analyzing as model state_dict.")
+            state_dict = loaded_data[PyTorchCheckpointKeys.MODEL_STATE]
+            report["model_state_analysis"] = _generate_weight_report(state_dict)
+
+        elif all(isinstance(v, torch.Tensor) for v in loaded_data.values()):
+            report["notes"].append("File dictionary contains only tensors. Analyzing entire dictionary as model state_dict.")
+            state_dict = loaded_data
+            report["model_state_analysis"] = _generate_weight_report(state_dict)
+
+        else:
+            report["notes"].append("Could not identify a single model state_dict. See top_level_summary for all contents. No detailed weight analysis will be performed.")
+
+    elif isinstance(loaded_data, nn.Module):
+        # --- Case 2: Loaded data is a full pickled model ---
+        report["notes"].append("File is a full, pickled nn.Module. This is not recommended. Extracting state_dict() for analysis.")
+        state_dict = loaded_data.state_dict()
+        report["model_state_analysis"] = _generate_weight_report(state_dict)
+
+    else:
+        # --- Case 3: Unrecognized format (e.g., single tensor, list) ---
+        _LOGGER.error(f"Could not parse .pth file. Loaded data is of type {type(loaded_data)}, not a dict or nn.Module.")
+        raise ValueError()
+
+    # --- 5. Save Report ---
+    custom_logger(data=report,
+                  save_directory=output_dir,
+                  log_name=UtilityKeys.PTH_FILE + pth_name,
+                  dict_as="json")
+
+
+def _generate_weight_report(state_dict: dict) -> dict:
+    """
+    Internal helper to analyze a state_dict and return a structured report.
+
+    Args:
+        state_dict (dict): The model state_dict to analyze.
+
+    Returns:
+        dict: A report containing total parameters and a per-parameter breakdown.
+    """
+    weight_report = {}
+    total_params = 0
+    if not isinstance(state_dict, dict):
+        _LOGGER.warning(f"Attempted to generate weight report on non-dict type: {type(state_dict)}")
+        return {"error": "Input was not a dictionary."}
+
+    for key, tensor in state_dict.items():
+        if not isinstance(tensor, torch.Tensor):
+            _LOGGER.warning(f"Skipping key '{key}' in state_dict: value is not a tensor (type: {type(tensor)}).")
+            weight_report[key] = {
+                "type": str(type(tensor)),
+                "value_preview": str(tensor)[:50] # Show a preview
+            }
+            continue
+        weight_report[key] = {
+            "shape": list(tensor.shape),
+            "dtype": str(tensor.dtype),
+            "requires_grad": tensor.requires_grad,
+            "num_elements": tensor.numel()
+        }
+        total_params += tensor.numel()
+
+    return {
+        "total_parameters": total_params,
+        "parameter_key_count": len(weight_report),
+        "parameters": weight_report
+    }
+
+
+def set_parameter_requires_grad(
+    model: nn.Module,
+    unfreeze_last_n_params: int,
+) -> int:
+    """
+    Freezes or unfreezes parameters in a model based on unfreeze_last_n_params.
+
+    - N = 0: Freezes ALL parameters.
+    - N > 0 and N < total: Freezes ALL parameters, then unfreezes the last N.
+    - N >= total: Unfreezes ALL parameters.
+
+    Note: 'N' refers to individual parameter tensors (e.g., `layer.weight`
+    or `layer.bias`), not modules or layers. For example, to unfreeze
+    the final nn.Linear layer, you would use N=2 (for its weight and bias).
+
+    Args:
+        model (nn.Module): The model to modify.
+        unfreeze_last_n_params (int):
+            The number of parameter tensors to unfreeze, starting from
+            the end of the model.
+
+    Returns:
+        int: The total number of individual parameters (elements) that were set to `requires_grad=True`.
+    """
+    if unfreeze_last_n_params < 0:
+        _LOGGER.error(f"unfreeze_last_n_params must be >= 0, but got {unfreeze_last_n_params}")
+        raise ValueError()
+
+    # --- Step 1: Get all parameter tensors ---
+    all_params = list(model.parameters())
+    total_param_tensors = len(all_params)
+
+    # --- Case 1: N = 0 (Freeze ALL parameters) ---
+    # early exit for the "freeze all" case.
+    if unfreeze_last_n_params == 0:
+        params_frozen = _set_params_grad(all_params, requires_grad=False)
+        _LOGGER.warning(f"Froze all {total_param_tensors} parameter tensors ({params_frozen} total elements).")
+        return 0 # 0 parameters unfrozen
+
+    # --- Case 2: N >= total (Unfreeze ALL parameters) ---
+    if unfreeze_last_n_params >= total_param_tensors:
+        if unfreeze_last_n_params > total_param_tensors:
+            _LOGGER.warning(f"Requested to unfreeze {unfreeze_last_n_params} params, but model only has {total_param_tensors}. Unfreezing all.")
+
+        params_unfrozen = _set_params_grad(all_params, requires_grad=True)
+        _LOGGER.info(f"Unfroze all {total_param_tensors} parameter tensors ({params_unfrozen} total elements) for training.")
+        return params_unfrozen
+
+    # --- Case 3: 0 < N < total (Standard: Freeze all, unfreeze last N) ---
+    # Freeze ALL
+    params_frozen = _set_params_grad(all_params, requires_grad=False)
+    _LOGGER.info(f"Froze {params_frozen} parameters.")
+
+    # Unfreeze the last N
+    params_to_unfreeze = all_params[-unfreeze_last_n_params:]
+
+    # these are all False, so the helper will set them to True
+    params_unfrozen = _set_params_grad(params_to_unfreeze, requires_grad=True)
+
+    _LOGGER.info(f"Unfroze the last {unfreeze_last_n_params} parameter tensors ({params_unfrozen} total elements) for training.")
+
+    return params_unfrozen
+
+
+def _set_params_grad(
+    params: Iterable[nn.Parameter],
+    requires_grad: bool
+) -> int:
+    """
+    A helper function to set the `requires_grad` attribute for an iterable
+    of parameters and return the total number of elements changed.
+    """
+    params_changed = 0
+    for param in params:
+        if param.requires_grad != requires_grad:
+            param.requires_grad = requires_grad
+            params_changed += param.numel()
+    return params_changed
+
 def info():
     _script_info(__all__)
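A short sketch of how the three new utilities above fit together on a toy model; the model, checkpoint path, and output directories are assumptions made for illustration, not part of the release.

```python
# Illustrative use of the new ML_utilities helpers (toy model and paths are assumptions).
import torch
from torch import nn
from ml_tools.ML_utilities import (
    get_model_parameters,
    inspect_pth_file,
    set_parameter_requires_grad,
)

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))

# Total vs. trainable parameter counts; also written to JSON because save_dir is given.
report = get_model_parameters(model, save_dir="reports/params")
print(report)

# Freeze everything except the last Linear layer: its weight and bias are 2 parameter tensors.
trainable_elements = set_parameter_requires_grad(model, unfreeze_last_n_params=2)
print(f"Elements left trainable: {trainable_elements}")

# Save a plain state_dict and produce a JSON summary of the checkpoint's contents.
torch.save(model.state_dict(), "toy_model.pth")
inspect_pth_file("toy_model.pth", save_dir="reports/pth_reports")
```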
ml_tools/custom_logger.py
CHANGED
@@ -1,6 +1,6 @@
 from pathlib import Path
 from datetime import datetime
-from typing import Union, List, Dict, Any
+from typing import Union, List, Dict, Any, Literal
 import traceback
 import json
 import csv
@@ -29,6 +29,7 @@ def custom_logger(
     ],
     save_directory: Union[str, Path],
     log_name: str,
+    dict_as: Literal['auto', 'json', 'csv'] = 'auto',
 ) -> None:
     """
     Logs various data types to corresponding output formats:
@@ -36,10 +37,10 @@ def custom_logger(
     - list[Any] → .txt
         Each element is written on a new line.

-    - dict[str, list[Any]] → .csv
+    - dict[str, list[Any]] → .csv (if dict_as='auto' or 'csv')
         Dictionary is treated as tabular data; keys become columns, values become rows.

-    - dict[str, scalar] → .json
+    - dict[str, scalar] → .json (if dict_as='auto' or 'json')
         Dictionary is treated as structured data and serialized as JSON.

     - str → .log
@@ -52,26 +53,43 @@ def custom_logger(
         data: The data to be logged. Must be one of the supported types.
         save_directory: Directory where the log will be saved. Created if it does not exist.
         log_name: Base name for the log file. Timestamp will be appended automatically.
+        dict_as ('auto'|'json'|'csv'):
+            - 'auto': Guesses format (JSON or CSV) based on dictionary content.
+            - 'json': Forces .json format for any dictionary.
+            - 'csv': Forces .csv format. Will fail if dict values are not all lists.

     Raises:
         ValueError: If the data type is unsupported.
     """
     try:
+        if not isinstance(data, BaseException) and not data:
+            _LOGGER.warning("Empty data received. No log file will be saved.")
+            return
+
         save_path = make_fullpath(save_directory, make=True)

         timestamp = datetime.now().strftime(r"%Y%m%d_%H%M%S")
         log_name = sanitize_filename(log_name)

         base_path = save_path / f"{log_name}_{timestamp}"
-
+
+        # Router
         if isinstance(data, list):
             _log_list_to_txt(data, base_path.with_suffix(".txt"))

         elif isinstance(data, dict):
-            if
-                _log_dict_to_csv(data, base_path.with_suffix(".csv"))
-            else:
+            if dict_as == 'json':
                 _log_dict_to_json(data, base_path.with_suffix(".json"))
+
+            elif dict_as == 'csv':
+                # This will raise a ValueError if data is not all lists
+                _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+
+            else: # 'auto' mode
+                if all(isinstance(v, list) for v in data.values()):
+                    _log_dict_to_csv(data, base_path.with_suffix(".csv"))
+                else:
+                    _log_dict_to_json(data, base_path.with_suffix(".json"))

         elif isinstance(data, str):
             _log_string_to_log(data, base_path.with_suffix(".log"))
@@ -83,7 +101,7 @@ def custom_logger(
         _LOGGER.error("Unsupported data type. Must be list, dict, str, or BaseException.")
         raise ValueError()

-        _LOGGER.info(f"Log saved
+        _LOGGER.info(f"Log saved as: '{base_path.name}'")

     except Exception:
         _LOGGER.exception(f"Log not saved.")
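A brief sketch of the new `dict_as` option in use; the directory and log names are placeholders.

```python
# How the new dict_as parameter routes dictionaries (paths and names are placeholders).
from ml_tools.custom_logger import custom_logger

scalars = {"accuracy": 0.93, "loss": 0.18}
columns = {"epoch": [1, 2, 3], "loss": [0.90, 0.45, 0.18]}

# 'auto' picks JSON here because the values are scalars, not lists.
custom_logger(data=scalars, save_directory="logs", log_name="final_metrics")

# 'auto' picks CSV here because every value is a list (tabular data).
custom_logger(data=columns, save_directory="logs", log_name="history")

# Force JSON for a dict of lists instead of the CSV that 'auto' would choose.
custom_logger(data=columns, save_directory="logs", log_name="history_as_json", dict_as="json")
```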
ml_tools/keys.py
CHANGED
@@ -80,6 +80,14 @@ class PyTorchCheckpointKeys:
     BEST_SCORE = "best_score"


+class UtilityKeys:
+    """Keys used for utility modules"""
+    MODEL_PARAMS_FILE = "model_parameters"
+    TOTAL_PARAMS = "Total Parameters"
+    TRAINABLE_PARAMS = "Trainable Parameters"
+    PTH_FILE = "pth report "
+
+
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
     OTHER_GUI = "OTHER"
ml_tools/utilities.py
CHANGED
@@ -7,16 +7,19 @@ from typing import Literal, Union, Optional, Any, Iterator, Tuple, overload
 from .path_manager import sanitize_filename, make_fullpath, list_csv_paths
 from ._script_info import _script_info
 from ._logger import _LOGGER
+from ._schema import FeatureSchema


 # Keep track of available tools
 __all__ = [
     "load_dataframe",
     "load_dataframe_greedy",
+    "load_dataframe_with_schema",
     "yield_dataframes_from_dir",
     "merge_dataframes",
     "save_dataframe_filename",
     "save_dataframe",
+    "save_dataframe_with_schema",
     "distribute_dataset_by_target",
     "train_dataset_orchestrator",
     "train_dataset_yielder"
@@ -174,6 +177,68 @@ def load_dataframe_greedy(directory: Union[str, Path],
     return df


+def load_dataframe_with_schema(
+    df_path: Union[str, Path],
+    schema: "FeatureSchema",
+    all_strings: bool = False,
+) -> Tuple[pd.DataFrame, str]:
+    """
+    Loads a CSV file into a Pandas DataFrame, strictly validating its
+    feature columns against a FeatureSchema.
+
+    This function wraps `load_dataframe`. After loading, it validates
+    that the first N columns of the DataFrame (where N =
+    len(schema.feature_names)) contain *exactly* the set of features
+    specified in the schema.
+
+    - If the columns are present but out of order, they are reordered.
+    - If any required feature is missing from the first N columns, it fails.
+    - If any extra column is found within the first N columns, it fails.
+
+    Columns *after* the first N are considered target columns and are
+    logged for verification.
+
+    Args:
+        df_path (str, Path):
+            The path to the CSV file.
+        schema (FeatureSchema):
+            The schema object to validate against.
+        all_strings (bool):
+            If True, loads all columns as string data types.
+
+    Returns:
+        (Tuple[pd.DataFrame, str]):
+            A tuple containing the loaded, validated (and possibly
+            reordered) pandas DataFrame and the base name of the file.
+
+    Raises:
+        ValueError:
+            - If the DataFrame is missing columns required by the schema
+              within its first N columns.
+            - If the DataFrame's first N columns contain unexpected
+              columns that are not in the schema.
+        FileNotFoundError:
+            If the file does not exist at the given path.
+    """
+    # Step 1: Load the dataframe using the original function
+    try:
+        df, df_name = load_dataframe(
+            df_path=df_path,
+            use_columns=None, # Load all columns for validation
+            kind="pandas",
+            all_strings=all_strings,
+            verbose=True
+        )
+    except Exception as e:
+        _LOGGER.error(f"Failed during initial load for schema validation: {e}")
+        raise e
+
+    # Step 2: Call the helper to validate and reorder
+    df_validated = _validate_and_reorder_schema(df=df, schema=schema)
+
+    return df_validated, df_name
+
+
 def yield_dataframes_from_dir(datasets_dir: Union[str,Path], verbose: bool=True):
     """
     Iterates over all CSV files in a given directory, loading each into a Pandas DataFrame.
@@ -330,6 +395,52 @@ def save_dataframe(df: Union[pd.DataFrame, pl.DataFrame], full_path: Path):
                             filename=full_path.name)


+def save_dataframe_with_schema(
+    df: pd.DataFrame,
+    full_path: Path,
+    schema: "FeatureSchema"
+) -> None:
+    """
+    Saves a pandas DataFrame to a CSV, strictly enforcing that the
+    first N columns match the FeatureSchema.
+
+    This function validates that the first N columns of the DataFrame
+    (where N = len(schema.feature_names)) contain *exactly* the set
+    of features specified in the schema.
+
+    - If the columns are present but out of order, they are reordered.
+    - If any required feature is missing from the first N columns, it fails.
+    - If any extra column is found within the first N columns, it fails.
+
+    Columns *after* the first N are considered target columns and are
+    logged for verification.
+
+    Args:
+        df (pd.DataFrame):
+            The DataFrame to save.
+        full_path (Path):
+            The complete file path where the DataFrame will be saved.
+        schema (FeatureSchema):
+            The schema object to validate against.
+
+    Raises:
+        ValueError:
+            - If the DataFrame is missing columns required by the schema
+              within its first N columns.
+            - If the DataFrame's first N columns contain unexpected
+              columns that are not in the schema.
+    """
+    if not isinstance(full_path, Path) or not full_path.suffix.endswith(".csv"):
+        _LOGGER.error('A path object pointing to a .csv file must be provided.')
+        raise ValueError()
+
+    # Call the helper to validate and reorder
+    df_to_save = _validate_and_reorder_schema(df=df, schema=schema)
+
+    # Call the original save function
+    save_dataframe(df=df_to_save, full_path=full_path)
+
+
 def distribute_dataset_by_target(
     df_or_path: Union[pd.DataFrame, str, Path],
     target_columns: list[str],
@@ -442,5 +553,72 @@ def train_dataset_yielder(
         yield (df_features, df_target, feature_names, target_col)


+def _validate_and_reorder_schema(
+    df: pd.DataFrame,
+    schema: "FeatureSchema"
+) -> pd.DataFrame:
+    """
+    Internal helper to validate and reorder a DataFrame against a schema.
+
+    Checks for missing, extra, and out-of-order feature columns
+    (the first N columns). Returns a reordered DataFrame if necessary.
+    Logs all actions.
+
+    Raises:
+        ValueError: If validation fails.
+    """
+    # Get schema and DataFrame column info
+    expected_features = list(schema.feature_names)
+    expected_set = set(expected_features)
+    n_features = len(expected_features)
+
+    all_df_columns = df.columns.to_list()
+
+    # --- Strict Validation ---
+
+    # 0. Check if DataFrame is long enough
+    if len(all_df_columns) < n_features:
+        _LOGGER.error(f"DataFrame has only {len(all_df_columns)} columns, but schema requires {n_features} features.")
+        raise ValueError()
+
+    df_feature_cols = all_df_columns[:n_features]
+    df_feature_set = set(df_feature_cols)
+    df_target_cols = all_df_columns[n_features:]
+
+    # 1. Check for missing features
+    missing_from_df = expected_set - df_feature_set
+    if missing_from_df:
+        _LOGGER.error(f"DataFrame's first {n_features} columns are missing required schema features: {missing_from_df}")
+        raise ValueError()
+
+    # 2. Check for extra (unexpected) features
+    extra_in_df = df_feature_set - expected_set
+    if extra_in_df:
+        _LOGGER.error(f"DataFrame's first {n_features} columns contain unexpected columns: {extra_in_df}")
+        raise ValueError()
+
+    # --- Reordering ---
+
+    df_to_process = df
+
+    # If we pass validation, the sets are equal. Now check order.
+    if df_feature_cols == expected_features:
+        _LOGGER.info("DataFrame feature columns already match schema order.")
+    else:
+        _LOGGER.warning("DataFrame feature columns do not match schema order. Reordering...")
+
+        # Rebuild the DataFrame with the correct feature order + target columns
+        new_order = expected_features + df_target_cols
+        df_to_process = df[new_order]
+
+    # Log the presumed target columns for user verification
+    if not df_target_cols:
+        _LOGGER.warning(f"No target columns were found after index {n_features-1}.")
+    else:
+        _LOGGER.info(f"Presumed Target Columns: {df_target_cols}")
+
+    return df_to_process # type: ignore
+
+
 def info():
     _script_info(__all__)
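A minimal sketch of the schema-validated load/save round trip added above; as in the earlier MiceImputer sketch, the `FeatureSchema` construction and the paths are illustrative assumptions.

```python
# Hypothetical round trip with the new schema-aware helpers (schema values and paths are assumptions).
from pathlib import Path
from ml_tools._schema import FeatureSchema          # import location inferred from this diff
from ml_tools.utilities import load_dataframe_with_schema, save_dataframe_with_schema

schema = FeatureSchema(
    feature_names=["age", "income", "color"],
    categorical_feature_names=["color"],
    categorical_index_map={2: 3},  # assumed semantics, as in the MiceImputer sketch
)

# Validates that the first len(schema.feature_names) columns are exactly the schema
# features (reordering them if needed); later columns are treated as targets.
df, df_name = load_dataframe_with_schema(df_path="data/train.csv", schema=schema)

# Same validation on the way out; raises ValueError if the feature block does not match.
save_dataframe_with_schema(df=df, full_path=Path("data/train_validated.csv"), schema=schema)
```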
Files without changes: WHEEL, licenses/LICENSE, licenses/LICENSE-THIRD-PARTY.md, top_level.txt (listed with +0 -0 in the summary above).