dragon-ml-toolbox 5.3.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -1,7 +1,7 @@
  from ._script_info import _script_info
  from ._logger import _LOGGER
  from .path_manager import make_fullpath, list_files_by_extension
- from .keys import ModelSaveKeys
+ from .keys import EnsembleKeys

  from typing import Union, Literal, Dict, Any, Optional, List
  from pathlib import Path
@@ -49,9 +49,9 @@ class InferenceHandler:
  verbose=self.verbose,
  raise_on_error=True) # type: ignore

- model: Any = full_object[ModelSaveKeys.MODEL]
- target_name: str = full_object[ModelSaveKeys.TARGET]
- feature_names_list: List[str] = full_object[ModelSaveKeys.FEATURES]
+ model: Any = full_object[EnsembleKeys.MODEL]
+ target_name: str = full_object[EnsembleKeys.TARGET]
+ feature_names_list: List[str] = full_object[EnsembleKeys.FEATURES]

  # Check that feature names match
  if self._feature_names is None:
@@ -102,8 +102,8 @@ class InferenceHandler:
  else: # Classification
  label = model.predict(features)[0]
  probabilities = model.predict_proba(features)[0]
- results[target_name] = {ModelSaveKeys.CLASSIFICATION_LABEL: label,
- ModelSaveKeys.CLASSIFICATION_PROBABILITIES: probabilities}
+ results[target_name] = {EnsembleKeys.CLASSIFICATION_LABEL: label,
+ EnsembleKeys.CLASSIFICATION_PROBABILITIES: probabilities}

  if self.verbose:
  _LOGGER.info("✅ Inference process complete.")
@@ -170,15 +170,15 @@ def model_report(
  # --- 2. Deserialize and Extract Info ---
  try:
  full_object: dict = _deserialize_object(model_p) # type: ignore
- model = full_object[ModelSaveKeys.MODEL]
- target = full_object[ModelSaveKeys.TARGET]
- features = full_object[ModelSaveKeys.FEATURES]
+ model = full_object[EnsembleKeys.MODEL]
+ target = full_object[EnsembleKeys.TARGET]
+ features = full_object[EnsembleKeys.FEATURES]
  except FileNotFoundError:
  _LOGGER.error(f"❌ Model file not found at '{model_p}'")
  raise
  except (KeyError, TypeError) as e:
  _LOGGER.error(
- f"❌ The serialized object is missing required keys '{ModelSaveKeys.MODEL}', '{ModelSaveKeys.TARGET}', '{ModelSaveKeys.FEATURES}'"
+ f"❌ The serialized object is missing required keys '{EnsembleKeys.MODEL}', '{EnsembleKeys.TARGET}', '{EnsembleKeys.FEATURES}'"
  )
  raise e

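Note on the rename above: serialized model artifacts are still plain dictionaries, so downstream code only needs to switch the key class from ModelSaveKeys to EnsembleKeys. A minimal unpacking sketch, assuming (as the diff shows) the artifact is a dict keyed by EnsembleKeys.MODEL/TARGET/FEATURES; the helper name is illustrative:

    from typing import Any, List
    from ml_tools.keys import EnsembleKeys

    def unpack_artifact(full_object: dict) -> tuple:
        # Layout matches what _save_model writes (see the second file below):
        # {EnsembleKeys.MODEL: ..., EnsembleKeys.FEATURES: [...], EnsembleKeys.TARGET: "..."}
        model: Any = full_object[EnsembleKeys.MODEL]
        target_name: str = full_object[EnsembleKeys.TARGET]
        feature_names: List[str] = full_object[EnsembleKeys.FEATURES]
        return model, target_name, feature_names
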
@@ -1,12 +1,8 @@
  import pandas as pd
  import numpy as np
- import seaborn # Use plot styling
- import matplotlib.pyplot as plt
- from matplotlib.colors import Colormap
- from matplotlib import rcdefaults

  from pathlib import Path
- from typing import Literal, Union, Optional, Iterator, Tuple
+ from typing import Literal, Union, Optional

  from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
  from imblearn.under_sampling import RandomUnderSampler
@@ -15,14 +11,20 @@ import xgboost as xgb
  import lightgbm as lgb

  from sklearn.model_selection import train_test_split
- from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
- import shap
+ from sklearn.base import clone

- from .utilities import yield_dataframes_from_dir, serialize_object
+ from .utilities import yield_dataframes_from_dir, serialize_object, train_dataset_yielder
  from .path_manager import sanitize_filename, make_fullpath
  from ._script_info import _script_info
- from .keys import ModelSaveKeys
+ from .keys import EnsembleKeys
  from ._logger import _LOGGER
+ from .ensemble_evaluation import (evaluate_model_classification,
+ plot_roc_curve,
+ plot_precision_recall_curve,
+ plot_calibration_curve,
+ evaluate_model_regression,
+ get_shap_values,
+ plot_learning_curves)

  import warnings # Ignore warnings
  warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -31,14 +33,9 @@ warnings.filterwarnings('ignore', category=UserWarning)


  __all__ = [
- "dataset_yielder",
  "RegressionTreeModels",
  "ClassificationTreeModels",
  "dataset_pipeline",
- "evaluate_model_classification",
- "plot_roc_curve",
- "evaluate_model_regression",
- "get_shap_values",
  "train_test_pipeline",
  "run_ensemble_pipeline",
  ]
@@ -48,34 +45,7 @@ HandleImbalanceStrategy = Literal[
  "ADASYN", "SMOTE", "RAND_OVERSAMPLE", "RAND_UNDERSAMPLE", "by_model", None
  ]

- TaskType = Literal[
- "classification", "regression"
- ]
-
- ###### 1. Dataset Loader ######
- def dataset_yielder(
- df: pd.DataFrame,
- target_cols: list[str]
- ) -> Iterator[Tuple[pd.DataFrame, pd.Series, list[str], str]]:
- """
- Yields one tuple at a time:
- (features_dataframe, target_series, feature_names, target_name)
-
- Skips any target columns not found in the DataFrame.
- """
- # Determine which target columns actually exist in the DataFrame
- valid_targets = [col for col in target_cols if col in df.columns]
-
- # Features = all columns excluding valid target columns
- df_features = df.drop(columns=valid_targets)
- feature_names = df_features.columns.to_list()
-
- for target_col in valid_targets:
- df_target = df[target_col]
- yield (df_features, df_target, feature_names, target_col)
-
-
- ###### 2. Initialize Models ######
+ ###### 1. Initialize Models ######
  class RegressionTreeModels:
  """
  A factory class for creating and configuring multiple gradient boosting regression models
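The loader removed above is not gone: the new import block pulls train_dataset_yielder from .utilities, so the generator now lives there under a new name. Assuming the move kept the 5.3.0 semantics shown in the deleted dataset_yielder, its behavior amounts to:

    from typing import Iterator, Tuple
    import pandas as pd

    def train_dataset_yielder(
        df: pd.DataFrame, target_cols: list[str]
    ) -> Iterator[Tuple[pd.DataFrame, pd.Series, list[str], str]]:
        # Keep only target columns that actually exist in the DataFrame
        valid_targets = [col for col in target_cols if col in df.columns]
        # Features = every column except the valid targets
        df_features = df.drop(columns=valid_targets)
        feature_names = df_features.columns.to_list()
        for target_col in valid_targets:
            yield (df_features, df[target_col], feature_names, target_col)
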
@@ -345,7 +315,7 @@ class ClassificationTreeModels:
  return f"{self.__class__.__name__}(n_estimators={self.n_estimators}, max_depth={self.max_depth}, lr={self.lr}, L1={self.L1}, L2={self.L2}"


- ###### 3. Process Dataset ######
+ ###### 2. Process Dataset ######
  # function to split data into train and test
  def _split_data(features, target, test_size, random_state, task):
  X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=random_state,
@@ -375,7 +345,7 @@ def _resample(X_train: np.ndarray, y_train: pd.Series,
  return X_res, y_res

  # DATASET PIPELINE
- def dataset_pipeline(df_features: pd.DataFrame, df_target: pd.Series, task: TaskType,
+ def dataset_pipeline(df_features: pd.DataFrame, df_target: pd.Series, task: Literal["classification", "regression"],
  resample_strategy: HandleImbalanceStrategy,
  test_size: float=0.2, debug: bool=False, random_state: int=101):
  '''
@@ -412,7 +382,7 @@ def dataset_pipeline(df_features: pd.DataFrame, df_target: pd.Series, task: Task

  return X_train_oversampled, y_train_oversampled, X_test, y_test

- ###### 4. Train and Evaluation ######
+ ###### 3. Train and Evaluation ######
  # Trainer function
  def _train_model(model, train_features, train_target):
  model.fit(train_features, train_target)
@@ -435,381 +405,26 @@ def _save_model(trained_model, model_name: str, target_name:str, feature_names:
  #Sanitize filenames to save
  sanitized_target_name = sanitize_filename(target_name)
  filename = f"{model_name}_{sanitized_target_name}"
- to_save = {ModelSaveKeys.MODEL: trained_model,
- ModelSaveKeys.FEATURES: feature_names,
- ModelSaveKeys.TARGET: target_name}
+ to_save = {EnsembleKeys.MODEL: trained_model,
+ EnsembleKeys.FEATURES: feature_names,
+ EnsembleKeys.TARGET: target_name}

  serialize_object(obj=to_save, save_dir=save_directory, filename=filename, verbose=False, raise_on_error=True)

- # function to evaluate the model and save metrics (Classification)
- def evaluate_model_classification(
- model,
- model_name: str,
- save_dir: Union[str,Path],
- x_test_scaled: np.ndarray,
- single_y_test: np.ndarray,
- target_name: str,
- figsize: tuple = (10, 8),
- base_fontsize: int = 24,
- cmap: Colormap = plt.cm.Blues # type: ignore
- ) -> np.ndarray:
- """
- Evaluates a classification model, saves the classification report and confusion matrix plot.
-
- Parameters:
- model: Trained classifier with .predict() method
- model_name: Identifier for the model
- save_dir: Directory where results are saved
- x_test_scaled: Feature matrix for test set
- single_y_test: True targets
- target_name: Target name
- figsize: Size of the confusion matrix figure (width, height)
- fontsize: Font size used for title, axis labels and ticks
- cmap: Color map for the confusion matrix. Examples include:
- - plt.cm.Blues (default)
- - plt.cm.Greens
- - plt.cm.Oranges
- - plt.cm.Purples
- - plt.cm.Reds
- - plt.cm.cividis
- - plt.cm.inferno
-
- Returns:
- y_pred: Predicted class labels
- """
- save_path = make_fullpath(save_dir, make=True)
-
- y_pred = model.predict(x_test_scaled)
- accuracy = accuracy_score(single_y_test, y_pred)
-
- report = classification_report(
- single_y_test,
- y_pred,
- target_names=["Negative", "Positive"],
- output_dict=False
- )
-
- # Save text report
- sanitized_target_name = sanitize_filename(target_name)
- report_path = save_path / f"Classification_Report_{sanitized_target_name}.txt"
- with open(report_path, "w") as f:
- f.write(f"{model_name} - {target_name}\t\tAccuracy: {accuracy:.2f}\n")
- f.write("Classification Report:\n")
- f.write(report) # type: ignore
-
- # Create confusion matrix
- fig, ax = plt.subplots(figsize=figsize)
- disp = ConfusionMatrixDisplay.from_predictions(
- y_true=single_y_test,
- y_pred=y_pred,
- display_labels=["Negative", "Positive"],
- cmap=cmap,
- normalize="true",
- ax=ax
- )
-
- ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
- ax.tick_params(axis='both', labelsize=base_fontsize)
- ax.set_xlabel("Predicted label", fontsize=base_fontsize)
- ax.set_ylabel("True label", fontsize=base_fontsize)
-
- # Turn off gridlines
- ax.grid(False)
-
- # Manually update font size of cell texts
- for text in ax.texts:
- text.set_fontsize(base_fontsize+4)
-
- fig.tight_layout()
- fig_path = save_path / f"Confusion_Matrix_{sanitized_target_name}.svg"
- fig.savefig(fig_path, format="svg", bbox_inches="tight") # type: ignore
- plt.close(fig)
-
- return y_pred
-
- #Function to save ROC and ROC AUC (Classification)
- def plot_roc_curve(
- true_labels: np.ndarray,
- probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
- model_name: str,
- target_name: str,
- save_directory: Union[str,Path],
- color: str = "darkorange",
- figure_size: tuple = (10, 10),
- linewidth: int = 2,
- base_fontsize: int = 24,
- input_features: Optional[np.ndarray] = None,
- ) -> plt.Figure: # type: ignore
- """
- Plots the ROC curve and computes AUC for binary classification. Positive class is assumed to be in the second column of the probabilities array.
-
- Parameters:
- true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
- probabilities_or_model: either predicted probabilities (ndarray), or a trained model with attribute `.predict_proba()`.
- target_name: str, Target name.
- save_directory: str or Path, path to directory where figure is saved.
- color: color of the ROC curve. Accepts any valid Matplotlib color specification. Examples:
- - Named colors: "darkorange", "blue", "red", "green", "black"
- - Hex codes: "#1f77b4", "#ff7f0e"
- - RGB tuples: (0.2, 0.4, 0.6)
- - Colormap value: plt.cm.viridis(0.6)
- figure_size: Tuple for figure size (width, height).
- linewidth: int, width of the plotted ROC line.
- title_fontsize: int, font size of the title.
- label_fontsize: int, font size for axes labels.
- input_features: np.ndarray of shape (n_samples, n_features), required if a model is passed.
-
- Returns:
- fig: matplotlib Figure object
- """
-
- # Determine predicted probabilities
- if isinstance(probabilities_or_model, np.ndarray):
- # Input is already probabilities
- if probabilities_or_model.ndim == 2: # type: ignore
- y_score = probabilities_or_model[:, 1] # type: ignore
- else:
- y_score = probabilities_or_model
-
- elif hasattr(probabilities_or_model, "predict_proba"):
- if input_features is None:
- raise ValueError("input_features must be provided when using a classifier.")
-
- try:
- classes = probabilities_or_model.classes_ # type: ignore
- positive_class_index = list(classes).index(1)
- except (AttributeError, ValueError):
- positive_class_index = 1

- y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index] # type: ignore
-
- else:
- raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model with support for '.predict_proba()'.")
-
- # ROC and AUC
- fpr, tpr, _ = roc_curve(true_labels, y_score)
- auc_score = roc_auc_score(true_labels, y_score)
-
- # Plot
- fig, ax = plt.subplots(figsize=figure_size)
- ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
- ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)
-
- ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
- ax.set_xlabel("False Positive Rate", fontsize=base_fontsize)
- ax.set_ylabel("True Positive Rate", fontsize=base_fontsize)
- ax.tick_params(axis='both', labelsize=base_fontsize)
- ax.legend(loc="lower right", fontsize=base_fontsize)
- ax.grid(True)
-
- # Save figure
- save_path = make_fullpath(save_directory, make=True)
- sanitized_target_name = sanitize_filename(target_name)
- full_save_path = save_path / f"ROC_{sanitized_target_name}.svg"
- fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
-
- return fig
-
-
- # function to evaluate the model and save metrics (Regression)
- def evaluate_model_regression(model, model_name: str,
- save_dir: Union[str,Path],
- x_test_scaled: np.ndarray, single_y_test: np.ndarray,
- target_name: str,
- figure_size: tuple = (12, 8),
- alpha_transparency: float = 0.5,
- base_fontsize: int = 24):
- # Generate predictions
- y_pred = model.predict(x_test_scaled)
-
- # Calculate regression metrics
- mae = mean_absolute_error(single_y_test, y_pred)
- mse = mean_squared_error(single_y_test, y_pred)
- rmse = np.sqrt(mse)
- r2 = r2_score(single_y_test, y_pred)
-
- # Create formatted report
- sanitized_target_name = sanitize_filename(target_name)
- save_path = make_fullpath(save_dir, make=True)
- report_path = save_path / f"Regression_Report_{sanitized_target_name}.txt"
- with open(report_path, "w") as f:
- f.write(f"{model_name} - Regression Performance for '{target_name}'\n\n")
- f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
- f.write(f"Mean Squared Error (MSE): {mse:.4f}\n")
- f.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}\n")
- f.write(f"R² Score: {r2:.4f}\n")
-
- # Generate and save residual plot
- residuals = single_y_test - y_pred
- plt.figure(figsize=figure_size)
- plt.scatter(y_pred, residuals, alpha=alpha_transparency)
- plt.axhline(0, color='red', linestyle='--')
- plt.xlabel("Predicted Values", fontsize=base_fontsize)
- plt.ylabel("Residuals", fontsize=base_fontsize)
- plt.title(f"{model_name} - Residual Plot for {target_name}", fontsize=base_fontsize)
- plt.grid(True)
- plt.tight_layout()
- residual_path = save_path / f"Residual_Plot_{sanitized_target_name}.svg"
- plt.savefig(residual_path, bbox_inches='tight', format="svg")
- plt.close()
-
- # Create true vs predicted values plot
- plt.figure(figsize=figure_size)
- plt.scatter(single_y_test, y_pred, alpha=alpha_transparency)
- plt.plot([single_y_test.min(), single_y_test.max()],
- [single_y_test.min(), single_y_test.max()],
- 'k--', lw=2)
- plt.xlabel('True Values', fontsize=base_fontsize)
- plt.ylabel('Predictions', fontsize=base_fontsize)
- plt.title(f"{model_name} - True vs Predicted for {target_name}", fontsize=base_fontsize)
- plt.grid(True)
- plot_path = save_path / f"Regression_Plot_{sanitized_target_name}.svg"
- plt.savefig(plot_path, bbox_inches='tight', format="svg")
- plt.close()
-
- return y_pred
-
-
- # Get SHAP values
- def get_shap_values(
- model,
- model_name: str,
- save_dir: Union[str, Path],
- features_to_explain: np.ndarray,
- feature_names: list[str],
- target_name: str,
- task: Literal["classification", "regression"],
- max_display_features: int = 10,
- figsize: tuple = (16, 20),
- base_fontsize: int = 38,
- ):
- """
- Universal SHAP explainer for regression and classification.
- * Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
-
- * Use `X_test` (or a hold-out set) to see how the model explains unseen data.
-
- * Use the entire dataset to get the global view.
-
- Parameters:
- task: 'regression' or 'classification'.
- features_to_explain: Should match the model's training data format, including scaling.
- save_dir: Directory to save visualizations.
- """
- sanitized_target_name = sanitize_filename(target_name)
- global_save_path = make_fullpath(save_dir, make=True)
-
- def _apply_plot_style():
- styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
- for style in styles:
- if style in plt.style.available or style == 'default':
- plt.style.use(style)
- break
-
- def _configure_rcparams():
- plt.rc('font', size=base_fontsize)
- plt.rc('axes', titlesize=base_fontsize)
- plt.rc('axes', labelsize=base_fontsize)
- plt.rc('xtick', labelsize=base_fontsize)
- plt.rc('ytick', labelsize=base_fontsize + 2)
- plt.rc('legend', fontsize=base_fontsize)
- plt.rc('figure', titlesize=base_fontsize)
-
- def _create_shap_plot(shap_values, features, save_path: Path, plot_type: str, title: str):
- _apply_plot_style()
- _configure_rcparams()
- plt.figure(figsize=figsize)
-
- shap.summary_plot(
- shap_values=shap_values,
- features=features,
- feature_names=feature_names,
- plot_type=plot_type,
- show=False,
- plot_size=figsize,
- max_display=max_display_features,
- alpha=0.7,
- # color='viridis'
- )
-
- ax = plt.gca()
- ax.set_xlabel("SHAP Value Impact", fontsize=base_fontsize + 2, weight='bold', labelpad=20)
- plt.title(title, fontsize=base_fontsize + 2, pad=20, weight='bold')
-
- for tick in ax.get_xticklabels():
- tick.set_fontsize(base_fontsize)
- tick.set_rotation(30)
- for tick in ax.get_yticklabels():
- tick.set_fontsize(base_fontsize + 2)
-
- if plot_type == "dot":
- cb = plt.gcf().axes[-1]
- cb.set_ylabel("", size=1)
- cb.tick_params(labelsize=base_fontsize - 2)
-
- plt.savefig(save_path, bbox_inches='tight', facecolor='white', format="svg")
- plt.close()
- rcdefaults()
-
- def _plot_for_classification(shap_values, class_names):
- is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
-
- if is_multiclass:
- for class_shap, class_name in zip(shap_values, class_names):
- for plot_type in ["bar", "dot"]:
- _create_shap_plot(
- shap_values=class_shap,
- features=features_to_explain,
- save_path=global_save_path / f"SHAP_{sanitized_target_name}_Class{class_name}_{plot_type}.svg",
- plot_type=plot_type,
- title=f"{model_name} - {target_name} (Class {class_name})"
- )
- else:
- values = shap_values[1] if isinstance(shap_values, list) else shap_values
- for plot_type in ["bar", "dot"]:
- _create_shap_plot(
- shap_values=values,
- features=features_to_explain,
- save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
- plot_type=plot_type,
- title=f"{model_name} - {target_name}"
- )
-
- def _plot_for_regression(shap_values):
- for plot_type in ["bar", "dot"]:
- _create_shap_plot(
- shap_values=shap_values,
- features=features_to_explain,
- save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
- plot_type=plot_type,
- title=f"{model_name} - {target_name}"
- )
- #START_O
-
- explainer = shap.TreeExplainer(model)
- shap_values = explainer.shap_values(features_to_explain)
-
- if task == 'classification':
- try:
- class_names = model.classes_ if hasattr(model, 'classes_') else list(range(len(shap_values)))
- except Exception:
- class_names = list(range(len(shap_values)))
- _plot_for_classification(shap_values, class_names)
- else:
- _plot_for_regression(shap_values)
-
-
- # TRAIN TEST PIPELINE
- def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
+ # TRAIN EVALUATE PIPELINE
+ def train_test_pipeline(model, model_name: str, dataset_id: str, task: Literal["classification", "regression"],
  train_features: np.ndarray, train_target: np.ndarray,
  test_features: np.ndarray, test_target: np.ndarray,
  feature_names: list[str], target_name: str,
  save_dir: Union[str,Path],
- debug: bool=False, save_model: bool=False):
+ debug: bool=False, save_model: bool=False,
+ generate_learning_curves: bool = False):
  '''
  1. Train model.
  2. Evaluate model.
  3. SHAP values.
+ 4. [Optional] Plot learning curves.

  Returns: Tuple(Trained model, Test-set Predictions)
  '''
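The evaluation and SHAP helpers deleted in this hunk resurface as imports from .ensemble_evaluation, so only their home module changes. One detail of the removed plot_roc_curve is worth preserving in any reimplementation: when handed a fitted classifier rather than a probability array, it resolves the positive-class column through classes_ instead of hard-coding column 1. A condensed sketch of that score extraction (the function name is illustrative, the logic mirrors the removed code):

    import numpy as np

    def positive_class_scores(probabilities_or_model, input_features=None) -> np.ndarray:
        # Accept either raw probabilities or a fitted model, as the removed helper did.
        if isinstance(probabilities_or_model, np.ndarray):
            proba = probabilities_or_model
            return proba[:, 1] if proba.ndim == 2 else proba
        if hasattr(probabilities_or_model, "predict_proba"):
            if input_features is None:
                raise ValueError("input_features must be provided when using a classifier.")
            try:
                # Find which predict_proba column corresponds to label 1.
                idx = list(probabilities_or_model.classes_).index(1)
            except (AttributeError, ValueError):
                idx = 1
            return probabilities_or_model.predict_proba(input_features)[:, idx]
        raise TypeError("Expected a probability array or a model with predict_proba().")
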
@@ -823,7 +438,8 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
  _save_model(trained_model=trained_model, model_name=model_name,
  target_name=target_name, feature_names=feature_names,
  save_directory=local_save_directory)
-
+
+ # EVALUATION
  if task == "classification":
  y_pred = evaluate_model_classification(model=trained_model, model_name=model_name, save_dir=local_save_directory,
  x_test_scaled=test_features, single_y_test=test_target, target_name=target_name)
@@ -831,6 +447,14 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
  probabilities_or_model=trained_model, model_name=model_name,
  target_name=target_name, save_directory=local_save_directory,
  input_features=test_features)
+ plot_precision_recall_curve(true_labels=test_target,
+ probabilities_or_model=trained_model, model_name=model_name,
+ target_name=target_name, save_directory=local_save_directory,
+ input_features=test_features)
+ plot_calibration_curve(model=trained_model, model_name=model_name,
+ save_dir=local_save_directory,
+ x_test=test_features, y_test=test_target,
+ target_name=target_name)
  elif task == "regression":
  y_pred = evaluate_model_regression(model=trained_model, model_name=model_name, save_dir=local_save_directory,
  x_test_scaled=test_features, single_y_test=test_target, target_name=target_name)
@@ -842,12 +466,21 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
  get_shap_values(model=trained_model, model_name=model_name, save_dir=local_save_directory,
  features_to_explain=train_features, feature_names=feature_names, target_name=target_name, task=task)

+ if generate_learning_curves:
+ # Note: We use a *clone* of the initial model object to ensure we don't use the already trained one.
+ # The learning_curve function handles the fitting internally.
+ initial_model_instance = clone(model)
+
+ plot_learning_curves(estimator=initial_model_instance, X=train_features, y=train_target,
+ task=task, model_name=model_name, target_name=target_name,
+ save_directory=local_save_directory)
+
  return trained_model, y_pred

- ###### 5. Execution ######
+ ###### 4. Execution ######
  def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Path], target_columns: list[str], model_object: Union[RegressionTreeModels, ClassificationTreeModels],
  handle_classification_imbalance: HandleImbalanceStrategy=None, save_model: bool=False,
- test_size: float=0.2, debug:bool=False):
+ test_size: float=0.2, debug:bool=False, generate_learning_curves: bool = False):
  #Check models
  if isinstance(model_object, RegressionTreeModels):
  task = "regression"
@@ -870,7 +503,7 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
  #Yield imputed dataset
  for dataframe, dataframe_name in yield_dataframes_from_dir(datasets_path):
  #Yield features dataframe and target dataframe
- for df_features, df_target, feature_names, target_name in dataset_yielder(df=dataframe, target_cols=target_columns):
+ for df_features, df_target, feature_names, target_name in train_dataset_yielder(df=dataframe, target_cols=target_columns):
  #Dataset pipeline
  X_train, y_train, X_test, y_test = dataset_pipeline(df_features=df_features, df_target=df_target, task=task,
  resample_strategy=handle_classification_imbalance,
@@ -883,7 +516,8 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Pat
  train_features=X_train, train_target=y_train, # type: ignore
  test_features=X_test, test_target=y_test,
  feature_names=feature_names,target_name=target_name,
- debug=debug, save_dir=save_path, save_model=save_model)
+ debug=debug, save_dir=save_path, save_model=save_model,
+ generate_learning_curves=generate_learning_curves)

  _LOGGER.info("✅ Training and evaluation complete.")

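Beyond the relocations, this file gains one behavioral option: train_test_pipeline now fits learning curves on a clone of the untrained estimator (sklearn.base.clone), so the already trained model is never refitted, and run_ensemble_pipeline exposes this as generate_learning_curves. A hedged usage sketch, assuming the module path ml_tools.ensemble_learning and a default-constructible model factory; the paths and target column are illustrative, not taken from the package docs:

    from ml_tools.ensemble_learning import ClassificationTreeModels, run_ensemble_pipeline

    run_ensemble_pipeline(
        datasets_dir="data/imputed",              # hypothetical folder of datasets
        save_dir="results",                       # hypothetical output folder
        target_columns=["outcome"],               # hypothetical target column
        model_object=ClassificationTreeModels(),  # assumes defaults are accepted
        handle_classification_imbalance="SMOTE",  # one of the HandleImbalanceStrategy literals
        save_model=True,
        generate_learning_curves=True,            # new in 6.0.0
    )
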
ml_tools/keys.py CHANGED
@@ -1,4 +1,4 @@
- class LogKeys:
+ class PyTorchLogKeys:
  """
  Used internally for ML scripts module.

@@ -14,7 +14,7 @@ class LogKeys:
  BATCH_SIZE = 'size'


- class ModelSaveKeys:
+ class EnsembleKeys:
  """
  Used internally by ensemble_learning.
  """