dragon-ml-toolbox 1.3.2__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/METADATA +19 -2
- dragon_ml_toolbox-1.4.1.dist-info/RECORD +19 -0
- ml_tools/MICE_imputation.py +24 -6
- ml_tools/VIF_factor.py +224 -0
- ml_tools/data_exploration.py +74 -286
- ml_tools/datasetmaster.py +13 -1
- ml_tools/ensemble_learning.py +128 -129
- ml_tools/handle_excel.py +32 -9
- ml_tools/logger.py +10 -1
- ml_tools/particle_swarm_optimization.py +71 -34
- ml_tools/pytorch_models.py +13 -1
- ml_tools/trainer.py +10 -30
- ml_tools/utilities.py +122 -14
- ml_tools/vision_helpers.py +14 -1
- dragon_ml_toolbox-1.3.2.dist-info/RECORD +0 -18
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/top_level.txt +0 -0
ml_tools/ensemble_learning.py
CHANGED
@@ -21,7 +21,7 @@ from sklearn.preprocessing import StandardScaler, MaxAbsScaler, MinMaxScaler
 from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
 import shap
 
-from .utilities import yield_dataframes_from_dir
+from .utilities import yield_dataframes_from_dir, sanitize_filename
 
 import warnings # Ignore warnings
 warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -139,8 +139,9 @@ def get_models(task: Literal["classification", "regression"], random_state: int=
 
 ###### 3. Process Dataset ######
 # function to split data into train and test
-def _split_data(features, target, test_size, random_state):
-    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=random_state,
+def _split_data(features, target, test_size, random_state, task):
+    X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=random_state,
+                                                         stratify=target if task=="classification" else None)
     return X_train, X_test, y_train, y_test
 
 # function to standardize the data
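Note: `_split_data` now stratifies classification splits so that train and test sets keep the original class proportions. A minimal standalone sketch of the behavior, with toy data:

```python
# Minimal sketch of the stratified split now used by _split_data (toy data).
import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.DataFrame({"x": range(10), "y": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]})
X_train, X_test, y_train, y_test = train_test_split(
    df[["x"]], df["y"],
    test_size=0.2,
    random_state=101,
    stratify=df["y"],  # keep the 50/50 class ratio in both splits
)
print(y_test.value_counts())  # one sample of each class
```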
@@ -176,7 +177,7 @@ def _resample(X_train_scaled: np.ndarray, y_train: pd.Series,
     else:
         raise ValueError(f"Invalid resampling strategy: {strategy}")
 
-    X_res, y_res = resample_algorithm.fit_resample(X_train_scaled, y_train)
+    X_res, y_res, *_ = resample_algorithm.fit_resample(X_train_scaled, y_train)
     return X_res, y_res
 
 # DATASET PIPELINE
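Note: the starred unpacking makes `_resample` tolerant of resamplers whose `fit_resample` returns more than the `(X, y)` pair. A minimal sketch, assuming imbalanced-learn's SMOTE:

```python
# Minimal sketch: fit_resample returns (X_resampled, y_resampled); *_ absorbs
# any extra values a resampler variant might append.
import numpy as np
from imblearn.over_sampling import SMOTE

X = np.random.rand(100, 3)
y = np.array([0] * 90 + [1] * 10)  # imbalanced toy labels

X_res, y_res, *_ = SMOTE(random_state=101).fit_resample(X, y)
print(np.bincount(y_res))  # balanced, e.g. [90 90]
```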
@@ -199,7 +200,7 @@ def dataset_pipeline(df_features: pd.DataFrame, df_target: pd.Series, task: Lite
     print(f"\tUnique values for '{df_target.name}': {unique_values}")
 
     #Train test split
-    X_train, X_test, y_train, y_test = _split_data(features=df_features, target=df_target, test_size=test_size, random_state=random_state)
+    X_train, X_test, y_train, y_test = _split_data(features=df_features, target=df_target, test_size=test_size, random_state=random_state, task=task)
 
     #DEBUG
     if debug:
@@ -244,7 +245,9 @@ def _local_directories(model_name: str, dataset_id: str, save_dir: str):
 
 # save model
 def _save_model(trained_model, model_name: str, target_name:str, feature_names: list[str], save_directory: str, scaler_object: Union[StandardScaler, MinMaxScaler, MaxAbsScaler]):
-
+    #Sanitize filenames to save
+    sanitized_target_name = sanitize_filename(target_name)
+    full_path = os.path.join(save_directory, f"{model_name}_{sanitized_target_name}.joblib")
     joblib.dump({'model': trained_model, 'scaler':scaler_object, 'feature_names': feature_names, 'target_name':target_name}, full_path)
 
 # function to evaluate the model and save metrics (Classification)
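Note: `sanitize_filename` comes from `ml_tools.utilities` and its body is not part of this diff. A plausible minimal version (an assumption, shown only to illustrate why target names need cleaning before being embedded in `.joblib` and report paths):

```python
# Hypothetical sketch of a sanitize_filename helper -- NOT the package's code.
import re

def sanitize_filename(name: str) -> str:
    """Replace path separators, reserved characters, and whitespace with '_'."""
    return re.sub(r'[\\/:"*?<>|\s]+', "_", name.strip())

print(sanitize_filename("Yield (g/L)"))  # -> Yield_(g_L)
```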
@@ -297,7 +300,8 @@ def evaluate_model_classification(
     )
 
     # Save text report
-
+    sanitized_target_id = sanitize_filename(target_id)
+    report_path = os.path.join(save_dir, f"Classification_Report_{sanitized_target_id}.txt")
     with open(report_path, "w") as f:
         f.write(f"{model_name} - {target_id}\t\tAccuracy: {accuracy:.2f}\n")
         f.write("Classification Report:\n")
@@ -327,7 +331,7 @@ def evaluate_model_classification(
         text.set_fontsize(title_fontsize+4)
 
     fig.tight_layout()
-    fig_path = os.path.join(save_dir, f"Confusion_Matrix_{
+    fig_path = os.path.join(save_dir, f"Confusion_Matrix_{sanitized_target_id}.svg")
     fig.savefig(fig_path, format="svg", bbox_inches="tight")
     plt.close(fig)
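Note: the confusion-matrix figure is saved as SVG under the sanitized target name. A minimal standalone sketch of that save path using sklearn's `ConfusionMatrixDisplay` (toy labels):

```python
# Minimal sketch: render a confusion matrix and save it as SVG (toy labels).
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

y_true = [0, 1, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 1]

fig, ax = plt.subplots(figsize=(10, 10))
ConfusionMatrixDisplay.from_predictions(y_true, y_pred, ax=ax)
fig.tight_layout()
fig.savefig("Confusion_Matrix_example.svg", format="svg", bbox_inches="tight")
plt.close(fig)
```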
@@ -343,8 +347,7 @@ def plot_roc_curve(
     color: str = "darkorange",
     figure_size: tuple = (10, 10),
     linewidth: int = 2,
-
-    label_fontsize: int = 24,
+    base_fontsize: int = 24,
     input_features: Optional[np.ndarray] = None,
 ) -> plt.Figure: # type: ignore
     """
@@ -402,20 +405,22 @@ def plot_roc_curve(
     ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
     ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)
 
-    ax.set_title(f"{model_name} - {target_name}", fontsize=
-    ax.set_xlabel("False Positive Rate", fontsize=
-    ax.set_ylabel("True Positive Rate", fontsize=
-    ax.tick_params(axis='both', labelsize=
-    ax.legend(loc="lower right", fontsize=
+    ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+    ax.set_xlabel("False Positive Rate", fontsize=base_fontsize)
+    ax.set_ylabel("True Positive Rate", fontsize=base_fontsize)
+    ax.tick_params(axis='both', labelsize=base_fontsize)
+    ax.legend(loc="lower right", fontsize=base_fontsize)
     ax.grid(True)
 
     # Save figure
     os.makedirs(save_directory, exist_ok=True)
-
+    sanitized_target_name = sanitize_filename(target_name)
+    save_path = os.path.join(save_directory, f"ROC_{sanitized_target_name}.svg")
     fig.savefig(save_path, bbox_inches="tight", format="svg")
 
     return fig
 
+
 # function to evaluate the model and save metrics (Regression)
 def evaluate_model_regression(model, model_name: str,
                               save_dir: str,
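Note: `plot_roc_curve` is built on sklearn's `roc_curve` and `roc_auc_score`. A minimal standalone sketch of the computation it wraps (toy scores):

```python
# Minimal sketch of the ROC/AUC computation behind plot_roc_curve (toy data).
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score

y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]  # positive-class probabilities

fpr, tpr, _ = roc_curve(y_true, scores)
auc_score = roc_auc_score(y_true, scores)

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(fpr, tpr, color="darkorange", lw=2, label=f"AUC = {auc_score:.2f}")
ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)
ax.legend(loc="lower right")
fig.savefig("ROC_example.svg", bbox_inches="tight", format="svg")
plt.close(fig)
```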
@@ -423,8 +428,7 @@ def evaluate_model_regression(model, model_name: str,
     target_id: str,
     figure_size: tuple = (12, 8),
     alpha_transparency: float = 0.5,
-
-    normal_fontsize: int = 24):
+    base_fontsize: int = 24):
     # Generate predictions
     y_pred = model.predict(x_test_scaled)
 
@@ -435,7 +439,8 @@ def evaluate_model_regression(model, model_name: str,
     r2 = r2_score(single_y_test, y_pred)
 
     # Create formatted report
-
+    sanitized_target_id = sanitize_filename(target_id)
+    report_path = os.path.join(save_dir, f"Regression_Report_{sanitized_target_id}.txt")
     with open(report_path, "w") as f:
         f.write(f"{model_name} - {target_id} Regression Performance\n")
         f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
@@ -448,12 +453,12 @@ def evaluate_model_regression(model, model_name: str,
     plt.figure(figsize=figure_size)
     plt.scatter(y_pred, residuals, alpha=alpha_transparency)
     plt.axhline(0, color='red', linestyle='--')
-    plt.xlabel("Predicted Values", fontsize=
-    plt.ylabel("Residuals", fontsize=
-    plt.title(f"{model_name} - Residual Plot for {target_id}", fontsize=
+    plt.xlabel("Predicted Values", fontsize=base_fontsize)
+    plt.ylabel("Residuals", fontsize=base_fontsize)
+    plt.title(f"{model_name} - Residual Plot for {target_id}", fontsize=base_fontsize)
     plt.grid(True)
     plt.tight_layout()
-    plt.savefig(os.path.join(save_dir, f"Residual_Plot_{
+    plt.savefig(os.path.join(save_dir, f"Residual_Plot_{sanitized_target_id}.svg"), bbox_inches='tight', format="svg")
    plt.close()
 
     # Create true vs predicted values plot
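Note: the residual plot charts prediction error against predicted value; points scattered evenly around the zero line suggest an unbiased fit. A minimal standalone sketch with toy numbers:

```python
# Minimal sketch of the residual plot produced above (toy numbers).
import numpy as np
import matplotlib.pyplot as plt

y_test = np.array([3.0, 2.5, 4.1, 5.0])
y_pred = np.array([2.8, 2.7, 4.4, 4.6])
residuals = y_test - y_pred

plt.figure(figsize=(12, 8))
plt.scatter(y_pred, residuals, alpha=0.5)
plt.axhline(0, color="red", linestyle="--")  # zero-error reference line
plt.xlabel("Predicted Values")
plt.ylabel("Residuals")
plt.tight_layout()
plt.savefig("Residual_Plot_example.svg", bbox_inches="tight", format="svg")
plt.close()
```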
@@ -462,63 +467,66 @@ def evaluate_model_regression(model, model_name: str,
     plt.plot([single_y_test.min(), single_y_test.max()],
              [single_y_test.min(), single_y_test.max()],
              'k--', lw=2)
-    plt.xlabel('True Values', fontsize=
-    plt.ylabel('Predictions', fontsize=
-    plt.title(f"{model_name} - True vs Predicted for {target_id}", fontsize=
+    plt.xlabel('True Values', fontsize=base_fontsize)
+    plt.ylabel('Predictions', fontsize=base_fontsize)
+    plt.title(f"{model_name} - True vs Predicted for {target_id}", fontsize=base_fontsize)
     plt.grid(True)
-    plot_path = os.path.join(save_dir, f"Regression_Plot_{
+    plot_path = os.path.join(save_dir, f"Regression_Plot_{sanitized_target_id}.svg")
     plt.savefig(plot_path, bbox_inches='tight', format="svg")
     plt.close()
 
     return y_pred
 
+
 # Get SHAP values
-def get_shap_values(
-
-
-
-
-
-
-
-
-
-
-
+def get_shap_values(
+    model,
+    model_name: str,
+    save_dir: str,
+    features_to_explain: np.ndarray,
+    feature_names: list[str],
+    target_id: str,
+    task: Literal["classification", "regression"],
+    max_display_features: int = 10,
+    figsize: tuple = (16, 20),
+    base_fontsize: int = 38,
+    ):
     """
     Universal SHAP explainer for regression and classification.
-
-
-
+    * Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
+
+    * Use `X_test` (or a hold-out set) to see how the model explains unseen data.
+
+    * Use the entire dataset to get the global view.
 
     Parameters:
-
-
-
+        task: 'regression' or 'classification'
+        features_to_explain: Should match the model's training data format, including scaling.
+        save_dir: Directory to save visualizations
     """
-
-
-
-
-
-    preferred_styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
-    for style in preferred_styles:
+    sanitized_target_id = sanitize_filename(target_id)
+
+    def _apply_plot_style():
+        styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
+        for style in styles:
             if style in plt.style.available or style == 'default':
                 plt.style.use(style)
                 break
-
+
+    def _configure_rcparams():
+        plt.rc('font', size=base_fontsize)
+        plt.rc('axes', titlesize=base_fontsize)
+        plt.rc('axes', labelsize=base_fontsize)
+        plt.rc('xtick', labelsize=base_fontsize)
+        plt.rc('ytick', labelsize=base_fontsize + 2)
+        plt.rc('legend', fontsize=base_fontsize)
+        plt.rc('figure', titlesize=base_fontsize)
+
+    def _create_shap_plot(shap_values, features, save_path: str, plot_type: str, title: str):
+        _apply_plot_style()
+        _configure_rcparams()
         plt.figure(figsize=figsize)
-
-        #set rc parameters for better readability
-        plt.rc('font', size=label_fontsize)
-        plt.rc('axes', titlesize=title_fontsize)
-        plt.rc('axes', labelsize=label_fontsize)
-        plt.rc('xtick', labelsize=label_fontsize)
-        plt.rc('ytick', labelsize=label_fontsize)
-        plt.rc('legend', fontsize=label_fontsize)
-        plt.rc('figure', titlesize=title_fontsize)
-
-        # Create the SHAP plot
+
         shap.summary_plot(
             shap_values=shap_values,
             features=features,
@@ -528,85 +536,76 @@ def get_shap_values(model, model_name: str,
             plot_size=figsize,
             max_display=max_display_features,
             alpha=0.7,
-            color=
+            # color='viridis'
         )
-
-        # Add professional styling
+
         ax = plt.gca()
-        ax.set_xlabel("SHAP Value Impact", fontsize=
-
-
-
-        # Manually fix tick fonts
+        ax.set_xlabel("SHAP Value Impact", fontsize=base_fontsize + 2, weight='bold', labelpad=20)
+        plt.title(title, fontsize=base_fontsize + 2, pad=20, weight='bold')
+
         for tick in ax.get_xticklabels():
-            tick.set_fontsize(
-            tick.set_rotation(
+            tick.set_fontsize(base_fontsize)
+            tick.set_rotation(30)
         for tick in ax.get_yticklabels():
-            tick.set_fontsize(
+            tick.set_fontsize(base_fontsize + 2)
 
-        # Handle colorbar for dot plots
         if plot_type == "dot":
             cb = plt.gcf().axes[-1]
-            # cb.set_ylabel("Feature Value", size=label_fontsize)
             cb.set_ylabel("", size=1)
-            cb.tick_params(labelsize=
-
-
-        plt.savefig(
-            full_save_path,
-            bbox_inches='tight',
-            facecolor='white',
-            format="svg"
-        )
+            cb.tick_params(labelsize=base_fontsize - 2)
+
+        plt.savefig(save_path, bbox_inches='tight', facecolor='white', format="svg")
         plt.close()
-        rcdefaults()
-
-
-
-
-
-        # Handle different model types
-        if task == 'classification':
-            # Determine if multiclass
-            try:
-                is_multiclass = len(model.classes_) > 2
-                class_names = model.classes_
-            except AttributeError:
-                is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
-                class_names = list(range(len(shap_values))) if is_multiclass else [0, 1]
-
+        rcdefaults()
+
+    def _plot_for_classification(shap_values, class_names):
+        is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
+
         if is_multiclass:
-            for
+            for class_shap, class_name in zip(shap_values, class_names):
+                for plot_type in ["bar", "dot"]:
+                    _create_shap_plot(
+                        shap_values=class_shap,
+                        features=features_to_explain,
+                        save_path=os.path.join(save_dir, f"SHAP_{sanitized_target_id}_Class{class_name}_{plot_type}.svg"),
+                        plot_type=plot_type,
+                        title=f"{model_name} - {target_id} (Class {class_name})"
+                    )
+        else:
+            values = shap_values[1] if isinstance(shap_values, list) else shap_values
+            for plot_type in ["bar", "dot"]:
                 _create_shap_plot(
-                    shap_values=
+                    shap_values=values,
                    features=features_to_explain,
-
-                    full_save_path=os.path.join(save_dir, f"SHAP_{target_id}_Class{class_name}.svg"),
+                    save_path=os.path.join(save_dir, f"SHAP_{sanitized_target_id}_{plot_type}.svg"),
                     plot_type=plot_type,
-                    title=f"{model_name} - {target_id}
+                    title=f"{model_name} - {target_id}"
                 )
-
-
-
+
+    def _plot_for_regression(shap_values):
+        for plot_type in ["bar", "dot"]:
             _create_shap_plot(
-                shap_values=
+                shap_values=shap_values,
                 features=features_to_explain,
-
-                full_save_path=os.path.join(save_dir, f"SHAP_{target_id}.svg"),
+                save_path=os.path.join(save_dir, f"SHAP_{sanitized_target_id}_{plot_type}.svg"),
                 plot_type=plot_type,
                 title=f"{model_name} - {target_id}"
             )
-
-
-
-
-
-
-
-
-
-
-
+    #START_O
+
+    explainer = shap.TreeExplainer(model)
+    shap_values = explainer.shap_values(features_to_explain)
+
+    if task == 'classification':
+        try:
+            class_names = model.classes_ if hasattr(model, 'classes_') else list(range(len(shap_values)))
+        except Exception:
+            class_names = list(range(len(shap_values)))
+        _plot_for_classification(shap_values, class_names)
+    else:
+        _plot_for_regression(shap_values)
+
+
 # TRAIN TEST PIPELINE
 def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["classification", "regression"],
                         train_features: np.ndarray, train_target: np.ndarray,
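Note: the refactored `get_shap_values` drives everything through `shap.TreeExplainer`. A minimal standalone sketch of that flow (toy model and data; with some shap versions, classification output is a per-class list, which the `isinstance(..., list)` checks above accommodate):

```python
# Minimal sketch of the TreeExplainer flow used by get_shap_values (toy model).
import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, n_features=5, random_state=101)
model = RandomForestClassifier(random_state=101).fit(X, y)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Depending on the shap version, classification may yield a list of per-class
# arrays; mirror the module's own fallback when it does.
values = shap_values[1] if isinstance(shap_values, list) else shap_values
shap.summary_plot(values, features=X, plot_type="bar", show=False)
```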
@@ -653,7 +652,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["
     return trained_model, y_pred
 
 ###### 5. Execution ######
-def
+def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], task: Literal["classification", "regression"],
                          resample_strategy: Literal[r"ADASYN", r'SMOTE', r'RANDOM', r'UNDERSAMPLE', None]=None, scaler: Literal["standard", "minmax", "maxabs"]="minmax", save_model: bool=False,
                          test_size: float=0.2, debug:bool=False, L1_regularization: float=0.5, L2_regularization: float=0.5, learning_rate: float=0.005, random_state: int=101):
     #Check paths
@@ -672,15 +671,15 @@ def run_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], ta
     #Train models
     for model_name, model in models_dict.items():
         train_test_pipeline(model=model, model_name=model_name, dataset_id=dataframe_name, task=task,
-                            train_features=X_train, train_target=y_train,
+                            train_features=X_train, train_target=y_train, # type: ignore
                             test_features=X_test, test_target=y_test,
                             feature_names=feature_names,target_id=target_name, scaler_object=scaler_object,
                             debug=debug, save_dir=save_dir, save_model=save_model)
-    print("\
+    print("\n✅ Training and evaluation complete.")
 
 
 def _check_paths(datasets_dir: str, save_dir:str):
     if not os.path.isdir(save_dir):
-        os.makedirs(save_dir)
+        os.makedirs(save_dir)
     if not os.path.isdir(datasets_dir):
-        raise IOError(f"Datasets directory '{datasets_dir}' not found
+        raise IOError(f"Datasets directory '{datasets_dir}' not found.")
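Note: 1.4.1 renames the old `run_pipeline` entry point to `run_ensemble_pipeline` (the second hunk header above still shows the old name). A minimal usage sketch with placeholder paths:

```python
# Minimal usage sketch for the renamed entry point (placeholder paths).
from ml_tools.ensemble_learning import run_ensemble_pipeline

run_ensemble_pipeline(
    datasets_dir="data/datasets",   # directory of CSV datasets
    save_dir="results/ensemble",
    target_columns=["target"],
    task="classification",
    resample_strategy="SMOTE",
    scaler="minmax",
    save_model=True,
)
```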
ml_tools/handle_excel.py
CHANGED
@@ -2,6 +2,16 @@ import os
 from openpyxl import load_workbook, Workbook
 import pandas as pd
 from typing import List, Optional
+from utilities import _script_info, sanitize_filename
+
+
+__all__ = [
+    "unmerge_and_split_excel",
+    "unmerge_and_split_from_directory",
+    "validate_excel_schema",
+    "vertical_merge_transform_excel",
+    "horizontal_merge_transform_excel"
+]
 
 
 def unmerge_and_split_excel(filepath: str) -> None:
@@ -25,12 +35,12 @@ def unmerge_and_split_excel(filepath: str) -> None:
         ws = wb[sheet_name]
         new_wb = Workbook()
         new_ws = new_wb.active
-        new_ws.title = sheet_name
+        new_ws.title = sheet_name # type: ignore
 
         # Copy all cell values
         for row in ws.iter_rows():
             for cell in row:
-                new_ws.cell(row=cell.row, column=cell.column, value=cell.value)
+                new_ws.cell(row=cell.row, column=cell.column, value=cell.value) # type: ignore
 
         # Fill and unmerge merged regions
         for merged_range in list(ws.merged_cells.ranges):
@@ -41,10 +51,10 @@ def unmerge_and_split_excel(filepath: str) -> None:
             value = ws.cell(row=min_row, column=min_col).value
             for row in range(min_row, max_row + 1):
                 for col in range(min_col, max_col + 1):
-                    new_ws.cell(row=row, column=col, value=value)
+                    new_ws.cell(row=row, column=col, value=value) # type: ignore
 
         # Construct flat output file name
-        sanitized_sheet_name = sheet_name
+        sanitized_sheet_name = sanitize_filename(sheet_name)
         output_filename = f"{base_name}_{sanitized_sheet_name}.xlsx"
         output_path = os.path.join(base_dir, output_filename)
         new_wb.save(output_path)
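Note: both unmerge helpers share one idea: every cell covered by a merged range receives the value of the range's top-left cell. A minimal in-place sketch with openpyxl (the package functions copy into a new workbook instead; the path is a placeholder):

```python
# Minimal sketch: fill every cell of each merged range with its top-left value.
from openpyxl import load_workbook

wb = load_workbook("example.xlsx")  # placeholder path
ws = wb.active

for merged_range in list(ws.merged_cells.ranges):
    min_col, min_row, max_col, max_row = merged_range.bounds
    value = ws.cell(row=min_row, column=min_col).value
    ws.unmerge_cells(str(merged_range))  # variant: unmerge in place
    for row in range(min_row, max_row + 1):
        for col in range(min_col, max_col + 1):
            ws.cell(row=row, column=col, value=value)

wb.save("example_unmerged.xlsx")
```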
@@ -85,12 +95,12 @@ def unmerge_and_split_from_directory(input_dir: str, output_dir: str) -> None:
         ws = wb[sheet_name]
         new_wb = Workbook()
         new_ws = new_wb.active
-        new_ws.title = sheet_name
+        new_ws.title = sheet_name # type: ignore
 
         # Copy all cell values
         for row in ws.iter_rows():
             for cell in row:
-                new_ws.cell(row=cell.row, column=cell.column, value=cell.value)
+                new_ws.cell(row=cell.row, column=cell.column, value=cell.value) # type: ignore
 
         # Fill and unmerge merged regions
         for merged_range in list(ws.merged_cells.ranges):
@@ -101,10 +111,10 @@ def unmerge_and_split_from_directory(input_dir: str, output_dir: str) -> None:
             value = ws.cell(row=min_row, column=min_col).value
             for row in range(min_row, max_row + 1):
                 for col in range(min_col, max_col + 1):
-                    new_ws.cell(row=row, column=col, value=value)
+                    new_ws.cell(row=row, column=col, value=value) # type: ignore
 
         # Construct flat output file name
-        sanitized_sheet_name = sheet_name
+        sanitized_sheet_name = sanitize_filename(sheet_name)
         output_filename = f"{base_name}_{sanitized_sheet_name}.xlsx"
         output_path = os.path.join(output_dir, output_filename)
         new_wb.save(output_path)
@@ -151,7 +161,7 @@ def validate_excel_schema(
     wb = load_workbook(file_path, read_only=True)
     ws = wb.active # Only check the first worksheet
 
-    header = [cell.value for cell in next(ws.iter_rows(max_row=1))]
+    header = [cell.value for cell in next(ws.iter_rows(max_row=1))] # type: ignore
 
     if strict:
         if header != expected_columns:
@@ -202,6 +212,11 @@ def vertical_merge_transform_excel(
 
     if not excel_files:
         raise ValueError("No Excel files found in the target directory.")
+
+    # sanitize filename
+    csv_filename = sanitize_filename(csv_filename)
+    # make directory
+    os.makedirs(output_dir, exist_ok=True)
 
     csv_filename = csv_filename if csv_filename.endswith('.csv') else f"{csv_filename}.csv"
     csv_path = os.path.join(output_dir, csv_filename)
@@ -260,6 +275,11 @@ def horizontal_merge_transform_excel(
     excel_files = [f for f in raw_excel_files if not f.startswith('~')] # Exclude temporary files
     if not excel_files:
         raise ValueError("No Excel files found in the target directory.")
+
+    # sanitize filename
+    csv_filename = sanitize_filename(csv_filename)
+    # make directory
+    os.makedirs(output_dir, exist_ok=True)
 
     csv_filename = csv_filename if csv_filename.endswith('.csv') else f"{csv_filename}.csv"
     csv_path = os.path.join(output_dir, csv_filename)
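Note: the merge functions' bodies fall mostly outside this diff. A plausible sketch of the vertical case (an assumption, not the package's exact code): read every workbook in a directory, stack the rows, and write one CSV:

```python
# Hypothetical sketch of a vertical merge -- NOT the package's exact code.
import os
import pandas as pd

def vertical_merge_sketch(target_dir: str, output_dir: str, csv_filename: str) -> None:
    excel_files = [f for f in os.listdir(target_dir)
                   if f.endswith(".xlsx") and not f.startswith("~")]
    if not excel_files:
        raise ValueError("No Excel files found in the target directory.")
    os.makedirs(output_dir, exist_ok=True)
    frames = [pd.read_excel(os.path.join(target_dir, f)) for f in excel_files]
    merged = pd.concat(frames, ignore_index=True)  # stack rows vertically
    merged.to_csv(os.path.join(output_dir, csv_filename), index=False)
```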
@@ -308,3 +328,6 @@ def horizontal_merge_transform_excel(
     if duplicate_columns:
         print(f"⚠️ Duplicate columns: {duplicate_columns}")
 
+
+def info():
+    _script_info(__all__)
ml_tools/logger.py
CHANGED
@@ -5,7 +5,12 @@ import pandas as pd
 from openpyxl.styles import Font, PatternFill
 import traceback
 import json
-from ml_tools.utilities import sanitize_filename
+from ml_tools.utilities import sanitize_filename, _script_info
+
+
+__all__ = [
+    "custom_logger"
+]
 
 
 def custom_logger(
@@ -143,3 +148,7 @@ def _log_exception_to_log(exc: BaseException, path: str) -> None:
 def _log_dict_to_json(data: Dict[Any, Any], path: str) -> None:
     with open(path, 'w', encoding='utf-8') as f:
         json.dump(data, f, indent=4, ensure_ascii=False)
+
+
+def info():
+    _script_info(__all__)
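Note: both modules now expose an `info()` helper backed by `_script_info` from `ml_tools.utilities`, whose body is not in this diff. A plausible minimal version (an assumption):

```python
# Hypothetical sketch of the _script_info helper -- NOT the package's code.
def _script_info(all_names: list[str]) -> None:
    """Print the public names a module exports via __all__."""
    print("Available functions:")
    for name in all_names:
        print(f"  - {name}")

# e.g. ml_tools.logger.info() would print: custom_logger
```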