dragon-ml-toolbox 1.4.8__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -5,9 +5,9 @@ import seaborn as sns
5
5
  from IPython import get_ipython
6
6
  from IPython.display import clear_output
7
7
  import time
8
- from typing import Union, Literal, Dict, Tuple, List
9
- import os
10
- from .utilities import sanitize_filename, _script_info
8
+ from typing import Union, Literal, Dict, Tuple, List, Optional
9
+ from pathlib import Path
10
+ from .utilities import sanitize_filename, _script_info, make_fullpath
11
11
  import re
12
12
 
13
13
 
@@ -59,26 +59,48 @@ def summarize_dataframe(df: pd.DataFrame, round_digits: int = 2):
59
59
  return summary
60
60
 
61
61
 
62
- def drop_rows_with_missing_data(df: pd.DataFrame, threshold: float = 0.7) -> pd.DataFrame:
62
+ def drop_rows_with_missing_data(df: pd.DataFrame, targets: Optional[list[str]], threshold: float = 0.7) -> pd.DataFrame:
63
63
  """
64
- Drops rows with more than `threshold` fraction of missing values.
64
+ Drops rows from the DataFrame using a two-stage strategy:
65
+
66
+ 1. If `targets`, remove any row where all target columns are missing.
67
+ 2. Among features, drop those with more than `threshold` fraction of missing values.
65
68
 
66
69
  Parameters:
67
70
  df (pd.DataFrame): The input DataFrame.
68
- threshold (float): Fraction of missing values above which rows are dropped.
71
+ targets (list[str] | None): List of target column names.
72
+ threshold (float): Maximum allowed fraction of missing values in feature columns.
69
73
 
70
74
  Returns:
71
- pd.DataFrame: A new DataFrame without the dropped rows.
75
+ pd.DataFrame: A cleaned DataFrame with problematic rows removed.
72
76
  """
73
- missing_fraction = df.isnull().mean(axis=1)
74
- rows_to_drop = missing_fraction[missing_fraction > threshold].index
75
-
76
- if len(rows_to_drop) > 0:
77
- print(f"Dropping {len(rows_to_drop)} rows with more than {threshold*100:.0f}% missing data.")
77
+ df_clean = df.copy()
78
+
79
+ # Stage 1: Drop rows with all target columns missing
80
+ if targets is not None:
81
+ target_na = df_clean[targets].isnull().all(axis=1)
82
+ if target_na.any():
83
+ print(f"🧹 Dropping {target_na.sum()} rows with all target columns missing.")
84
+ df_clean = df_clean[~target_na]
85
+ else:
86
+ print("✅ No rows with all targets missing.")
78
87
  else:
79
- print(f"No rows have more than {threshold*100:.0f}% missing data.")
88
+ targets = []
89
+
90
+ # Stage 2: Drop rows based on feature column missing values
91
+ feature_cols = [col for col in df_clean.columns if col not in targets]
92
+ if feature_cols:
93
+ feature_na_frac = df_clean[feature_cols].isnull().mean(axis=1)
94
+ rows_to_drop = feature_na_frac[feature_na_frac > threshold].index
95
+ if len(rows_to_drop) > 0:
96
+ print(f"📉 Dropping {len(rows_to_drop)} rows with more than {threshold*100:.0f}% missing feature data.")
97
+ df_clean = df_clean.drop(index=rows_to_drop)
98
+ else:
99
+ print(f"✅ No rows exceed the {threshold*100:.0f}% missing feature data threshold.")
100
+ else:
101
+ print("⚠️ No feature columns available to evaluate.")
80
102
 
81
- return df.drop(index=rows_to_drop)
103
+ return df_clean
82
104
 
83
105
 
84
106
  def split_features_targets(df: pd.DataFrame, targets: list[str]):
@@ -205,13 +227,16 @@ def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFram
205
227
  return df_cont, df_bin # type: ignore
206
228
 
207
229
 
208
- def plot_correlation_heatmap(df: pd.DataFrame, save_dir: Union[str, None] = None, method: Literal["pearson", "kendall", "spearman"]="pearson", plot_title: str="Correlation Heatmap"):
230
+ def plot_correlation_heatmap(df: pd.DataFrame,
231
+ save_dir: Union[str, Path, None] = None,
232
+ plot_title: str="Correlation Heatmap",
233
+ method: Literal["pearson", "kendall", "spearman"]="pearson"):
209
234
  """
210
235
  Plots a heatmap of pairwise correlations between numeric features in a DataFrame.
211
236
 
212
237
  Args:
213
238
  df (pd.DataFrame): The input dataset.
214
- save_dir (str | None): If provided, the heatmap will be saved to this directory as a svg file.
239
+ save_dir (str | Path | None): If provided, the heatmap will be saved to this directory as a svg file.
215
240
  plot_title: To make different plots, or overwrite existing ones.
216
241
  method (str): Correlation method to use. Must be one of:
217
242
  - 'pearson' (default): measures linear correlation (assumes normally distributed data),
@@ -254,10 +279,13 @@ def plot_correlation_heatmap(df: pd.DataFrame, save_dir: Union[str, None] = None
254
279
  plt.tight_layout()
255
280
 
256
281
  if save_dir:
282
+ save_path = make_fullpath(save_dir, make=True)
257
283
  # sanitize the plot title to save the file
258
284
  plot_title = sanitize_filename(plot_title)
259
- os.makedirs(save_dir, exist_ok=True)
260
- full_path = os.path.join(save_dir, plot_title + ".svg")
285
+ plot_title = plot_title + ".svg"
286
+
287
+ full_path = save_path / plot_title
288
+
261
289
  plt.savefig(full_path, bbox_inches="tight", format='svg')
262
290
  print(f"Saved correlation heatmap: '{plot_title}.svg'")
263
291
 
@@ -322,7 +350,7 @@ def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin
322
350
  user_input_ = input("Press enter to continue")
323
351
 
324
352
 
325
- def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
353
+ def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
326
354
  """
327
355
  Plots and saves the value distributions for all (or selected) columns in a DataFrame,
328
356
  with adaptive binning for numerical columns when appropriate.
@@ -335,7 +363,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
335
363
 
336
364
  Args:
337
365
  df (pd.DataFrame): The input DataFrame whose columns are to be analyzed.
338
- save_dir (str): Directory path where the plots will be saved. Will be created if it does not exist.
366
+ save_dir (str | Path): Directory path where the plots will be saved. Will be created if it does not exist.
339
367
  bin_threshold (int): Minimum number of unique values required to trigger binning
340
368
  for numerical columns.
341
369
  skip_cols_with_key (str | None): If provided, any column whose name contains this
@@ -346,8 +374,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
346
374
  - All non-alphanumeric characters in column names are sanitized for safe file naming.
347
375
  - Colormap is automatically adapted based on the number of categories or bins.
348
376
  """
349
- if save_dir is not None:
350
- os.makedirs(save_dir, exist_ok=True)
377
+ save_path = make_fullpath(save_dir, make=True)
351
378
 
352
379
  dict_to_plot_std = dict()
353
380
  dict_to_plot_freq = dict()
@@ -384,13 +411,12 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
384
411
  view_freq = 100 * view_std / view_std.sum() # Percentage
385
412
  # view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages
386
413
 
387
- if save_dir:
388
- dict_to_plot_std[col] = dict(view_std)
389
- dict_to_plot_freq[col] = dict(view_freq)
390
- saved_plots += 1
414
+ dict_to_plot_std[col] = dict(view_std)
415
+ dict_to_plot_freq[col] = dict(view_freq)
416
+ saved_plots += 1
391
417
 
392
418
  # plot helper
393
- def _plot_helper(dict_: dict, target_dir: str, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
419
+ def _plot_helper(dict_: dict, target_dir: Path, ylabel: Literal["Frequency", "Counts"], base_fontsize: int=12):
394
420
  for col, data in dict_.items():
395
421
  safe_col = sanitize_filename(col)
396
422
 
@@ -412,15 +438,15 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
412
438
  plt.gca().set_facecolor('#f9f9f9')
413
439
  plt.tight_layout()
414
440
 
415
- plot_path = os.path.join(target_dir, f"{safe_col}.png")
441
+ plot_path = target_dir / f"{safe_col}.png"
416
442
  plt.savefig(plot_path, dpi=300, bbox_inches="tight")
417
443
  plt.close()
418
444
 
419
445
  # Save plots
420
- freq_dir = os.path.join(save_dir, "Distribution_Frequency")
421
- std_dir = os.path.join(save_dir, "Distribution_Counts")
422
- os.makedirs(freq_dir, exist_ok=True)
423
- os.makedirs(std_dir, exist_ok=True)
446
+ freq_dir = save_path / "Distribution_Frequency"
447
+ std_dir = save_path / "Distribution_Counts"
448
+ freq_dir.mkdir(parents=True, exist_ok=True)
449
+ std_dir.mkdir(parents=True, exist_ok=True)
424
450
  _plot_helper(dict_=dict_to_plot_std, target_dir=std_dir, ylabel="Counts")
425
451
  _plot_helper(dict_=dict_to_plot_freq, target_dir=freq_dir, ylabel="Frequency")
426
452
 
@@ -5,7 +5,7 @@ import matplotlib.pyplot as plt
5
5
  from matplotlib.colors import Colormap
6
6
  from matplotlib import rcdefaults
7
7
 
8
- import os
8
+ from pathlib import Path
9
9
  from typing import Literal, Union, Optional, Iterator, Tuple
10
10
 
11
11
  from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
@@ -19,7 +19,7 @@ from sklearn.model_selection import train_test_split
19
19
  from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
20
20
  import shap
21
21
 
22
- from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object
22
+ from .utilities import yield_dataframes_from_dir, sanitize_filename, _script_info, serialize_object, make_fullpath
23
23
 
24
24
  import warnings # Ignore warnings
25
25
  warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -469,30 +469,31 @@ def _train_model(model, train_features, train_target):
469
469
  return model
470
470
 
471
471
  # handle local directories
472
- def _local_directories(model_name: str, dataset_id: str, save_dir: str):
473
- dataset_dir = os.path.join(save_dir, dataset_id)
474
- if not os.path.isdir(dataset_dir):
475
- os.makedirs(dataset_dir)
472
+ def _local_directories(model_name: str, dataset_id: str, save_dir: Union[str,Path]):
473
+ save_path = make_fullpath(save_dir, make=True)
476
474
 
477
- model_dir = os.path.join(dataset_dir, model_name)
478
- if not os.path.isdir(model_dir):
479
- os.makedirs(model_dir)
475
+ dataset_dir = save_path / dataset_id
476
+ dataset_dir.mkdir(parents=True, exist_ok=True)
477
+
478
+ model_dir = dataset_dir / model_name
479
+ model_dir.mkdir(parents=True, exist_ok=True)
480
480
 
481
481
  return model_dir
482
482
 
483
483
  # save model
484
- def _save_model(trained_model, model_name: str, target_name:str, feature_names: list[str], save_directory: str):
484
+ def _save_model(trained_model, model_name: str, target_name:str, feature_names: list[str], save_directory: Union[str,Path]):
485
485
  #Sanitize filenames to save
486
486
  sanitized_target_name = sanitize_filename(target_name)
487
487
  filename = f"{model_name}_{sanitized_target_name}"
488
488
  to_save = {'model': trained_model, 'feature_names': feature_names, 'target_name':target_name}
489
+
489
490
  serialize_object(obj=to_save, save_dir=save_directory, filename=filename, verbose=False, raise_on_error=True)
490
491
 
491
492
  # function to evaluate the model and save metrics (Classification)
492
493
  def evaluate_model_classification(
493
494
  model,
494
495
  model_name: str,
495
- save_dir: str,
496
+ save_dir: Union[str,Path],
496
497
  x_test_scaled: np.ndarray,
497
498
  single_y_test: np.ndarray,
498
499
  target_name: str,
@@ -524,7 +525,7 @@ def evaluate_model_classification(
524
525
  Returns:
525
526
  y_pred: Predicted class labels
526
527
  """
527
- os.makedirs(save_dir, exist_ok=True)
528
+ save_path = make_fullpath(save_dir, make=True)
528
529
 
529
530
  y_pred = model.predict(x_test_scaled)
530
531
  accuracy = accuracy_score(single_y_test, y_pred)
@@ -538,7 +539,7 @@ def evaluate_model_classification(
538
539
 
539
540
  # Save text report
540
541
  sanitized_target_name = sanitize_filename(target_name)
541
- report_path = os.path.join(save_dir, f"Classification_Report_{sanitized_target_name}.txt")
542
+ report_path = save_path / f"Classification_Report_{sanitized_target_name}.txt"
542
543
  with open(report_path, "w") as f:
543
544
  f.write(f"{model_name} - {target_name}\t\tAccuracy: {accuracy:.2f}\n")
544
545
  f.write("Classification Report:\n")
@@ -568,7 +569,7 @@ def evaluate_model_classification(
568
569
  text.set_fontsize(base_fontsize+4)
569
570
 
570
571
  fig.tight_layout()
571
- fig_path = os.path.join(save_dir, f"Confusion_Matrix_{sanitized_target_name}.svg")
572
+ fig_path = save_path / f"Confusion_Matrix_{sanitized_target_name}.svg"
572
573
  fig.savefig(fig_path, format="svg", bbox_inches="tight")
573
574
  plt.close(fig)
574
575
 
@@ -580,7 +581,7 @@ def plot_roc_curve(
580
581
  probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
581
582
  model_name: str,
582
583
  target_name: str,
583
- save_directory: str,
584
+ save_directory: Union[str,Path],
584
585
  color: str = "darkorange",
585
586
  figure_size: tuple = (10, 10),
586
587
  linewidth: int = 2,
@@ -594,7 +595,7 @@ def plot_roc_curve(
594
595
  true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
595
596
  probabilities_or_model: either predicted probabilities (ndarray), or a trained model with attribute `.predict_proba()`.
596
597
  target_name: str, Target name.
597
- save_directory: str, path to directory where figure is saved.
598
+ save_directory: str or Path, path to directory where figure is saved.
598
599
  color: color of the ROC curve. Accepts any valid Matplotlib color specification. Examples:
599
600
  - Named colors: "darkorange", "blue", "red", "green", "black"
600
601
  - Hex codes: "#1f77b4", "#ff7f0e"
@@ -650,17 +651,17 @@ def plot_roc_curve(
650
651
  ax.grid(True)
651
652
 
652
653
  # Save figure
653
- os.makedirs(save_directory, exist_ok=True)
654
+ save_path = make_fullpath(save_directory, make=True)
654
655
  sanitized_target_name = sanitize_filename(target_name)
655
- save_path = os.path.join(save_directory, f"ROC_{sanitized_target_name}.svg")
656
- fig.savefig(save_path, bbox_inches="tight", format="svg")
656
+ full_save_path = save_path / f"ROC_{sanitized_target_name}.svg"
657
+ fig.savefig(full_save_path, bbox_inches="tight", format="svg")
657
658
 
658
659
  return fig
659
660
 
660
661
 
661
662
  # function to evaluate the model and save metrics (Regression)
662
663
  def evaluate_model_regression(model, model_name: str,
663
- save_dir: str,
664
+ save_dir: Union[str,Path],
664
665
  x_test_scaled: np.ndarray, single_y_test: np.ndarray,
665
666
  target_name: str,
666
667
  figure_size: tuple = (12, 8),
@@ -677,7 +678,8 @@ def evaluate_model_regression(model, model_name: str,
677
678
 
678
679
  # Create formatted report
679
680
  sanitized_target_name = sanitize_filename(target_name)
680
- report_path = os.path.join(save_dir, f"Regression_Report_{sanitized_target_name}.txt")
681
+ save_path = make_fullpath(save_dir, make=True)
682
+ report_path = save_path / f"Regression_Report_{sanitized_target_name}.txt"
681
683
  with open(report_path, "w") as f:
682
684
  f.write(f"{model_name} - Regression Performance for '{target_name}'\n\n")
683
685
  f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
@@ -695,7 +697,8 @@ def evaluate_model_regression(model, model_name: str,
695
697
  plt.title(f"{model_name} - Residual Plot for {target_name}", fontsize=base_fontsize)
696
698
  plt.grid(True)
697
699
  plt.tight_layout()
698
- plt.savefig(os.path.join(save_dir, f"Residual_Plot_{sanitized_target_name}.svg"), bbox_inches='tight', format="svg")
700
+ residual_path = save_path / f"Residual_Plot_{sanitized_target_name}.svg"
701
+ plt.savefig(residual_path, bbox_inches='tight', format="svg")
699
702
  plt.close()
700
703
 
701
704
  # Create true vs predicted values plot
@@ -708,7 +711,7 @@ def evaluate_model_regression(model, model_name: str,
708
711
  plt.ylabel('Predictions', fontsize=base_fontsize)
709
712
  plt.title(f"{model_name} - True vs Predicted for {target_name}", fontsize=base_fontsize)
710
713
  plt.grid(True)
711
- plot_path = os.path.join(save_dir, f"Regression_Plot_{sanitized_target_name}.svg")
714
+ plot_path = save_path / f"Regression_Plot_{sanitized_target_name}.svg"
712
715
  plt.savefig(plot_path, bbox_inches='tight', format="svg")
713
716
  plt.close()
714
717
 
@@ -719,7 +722,7 @@ def evaluate_model_regression(model, model_name: str,
719
722
  def get_shap_values(
720
723
  model,
721
724
  model_name: str,
722
- save_dir: str,
725
+ save_dir: Union[str, Path],
723
726
  features_to_explain: np.ndarray,
724
727
  feature_names: list[str],
725
728
  target_name: str,
@@ -737,11 +740,12 @@ def get_shap_values(
737
740
  * Use the entire dataset to get the global view.
738
741
 
739
742
  Parameters:
740
- task: 'regression' or 'classification'
743
+ task: 'regression' or 'classification'.
741
744
  features_to_explain: Should match the model's training data format, including scaling.
742
- save_dir: Directory to save visualizations
745
+ save_dir: Directory to save visualizations.
743
746
  """
744
747
  sanitized_target_name = sanitize_filename(target_name)
748
+ global_save_path = make_fullpath(save_dir, make=True)
745
749
 
746
750
  def _apply_plot_style():
747
751
  styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
@@ -759,7 +763,7 @@ def get_shap_values(
759
763
  plt.rc('legend', fontsize=base_fontsize)
760
764
  plt.rc('figure', titlesize=base_fontsize)
761
765
 
762
- def _create_shap_plot(shap_values, features, save_path: str, plot_type: str, title: str):
766
+ def _create_shap_plot(shap_values, features, save_path: Path, plot_type: str, title: str):
763
767
  _apply_plot_style()
764
768
  _configure_rcparams()
765
769
  plt.figure(figsize=figsize)
@@ -804,7 +808,7 @@ def get_shap_values(
804
808
  _create_shap_plot(
805
809
  shap_values=class_shap,
806
810
  features=features_to_explain,
807
- save_path=os.path.join(save_dir, f"SHAP_{sanitized_target_name}_Class{class_name}_{plot_type}.svg"),
811
+ save_path=global_save_path / f"SHAP_{sanitized_target_name}_Class{class_name}_{plot_type}.svg",
808
812
  plot_type=plot_type,
809
813
  title=f"{model_name} - {target_name} (Class {class_name})"
810
814
  )
@@ -814,7 +818,7 @@ def get_shap_values(
814
818
  _create_shap_plot(
815
819
  shap_values=values,
816
820
  features=features_to_explain,
817
- save_path=os.path.join(save_dir, f"SHAP_{sanitized_target_name}_{plot_type}.svg"),
821
+ save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
818
822
  plot_type=plot_type,
819
823
  title=f"{model_name} - {target_name}"
820
824
  )
@@ -824,7 +828,7 @@ def get_shap_values(
824
828
  _create_shap_plot(
825
829
  shap_values=shap_values,
826
830
  features=features_to_explain,
827
- save_path=os.path.join(save_dir, f"SHAP_{sanitized_target_name}_{plot_type}.svg"),
831
+ save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
828
832
  plot_type=plot_type,
829
833
  title=f"{model_name} - {target_name}"
830
834
  )
@@ -848,7 +852,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
848
852
  train_features: np.ndarray, train_target: np.ndarray,
849
853
  test_features: np.ndarray, test_target: np.ndarray,
850
854
  feature_names: list[str], target_name: str,
851
- save_dir: str,
855
+ save_dir: Union[str,Path],
852
856
  debug: bool=False, save_model: bool=False):
853
857
  '''
854
858
  1. Train model.
@@ -889,7 +893,7 @@ def train_test_pipeline(model, model_name: str, dataset_id: str, task: TaskType,
889
893
  return trained_model, y_pred
890
894
 
891
895
  ###### 5. Execution ######
892
- def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list[str], model_object: Union[RegressionTreeModels, ClassificationTreeModels],
896
+ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Path], target_columns: list[str], model_object: Union[RegressionTreeModels, ClassificationTreeModels],
893
897
  handle_classification_imbalance: HandleImbalanceStrategy=None, save_model: bool=False,
894
898
  test_size: float=0.2, debug:bool=False):
895
899
  #Check models
@@ -907,10 +911,11 @@ def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list
907
911
  raise TypeError(f"Unrecognized model {type(model_object)}")
908
912
 
909
913
  #Check paths
910
- _check_paths(datasets_dir, save_dir)
914
+ datasets_path = make_fullpath(datasets_dir)
915
+ save_path = make_fullpath(save_dir, make=True)
911
916
 
912
917
  #Yield imputed dataset
913
- for dataframe, dataframe_name in yield_dataframes_from_dir(datasets_dir):
918
+ for dataframe, dataframe_name in yield_dataframes_from_dir(datasets_path):
914
919
  #Yield features dataframe and target dataframe
915
920
  for df_features, df_target, feature_names, target_name in dataset_yielder(df=dataframe, target_cols=target_columns):
916
921
  #Dataset pipeline
@@ -925,15 +930,8 @@ def run_ensemble_pipeline(datasets_dir: str, save_dir: str, target_columns: list
925
930
  train_features=X_train, train_target=y_train, # type: ignore
926
931
  test_features=X_test, test_target=y_test,
927
932
  feature_names=feature_names,target_name=target_name,
928
- debug=debug, save_dir=save_dir, save_model=save_model)
933
+ debug=debug, save_dir=save_path, save_model=save_model)
929
934
  print("\n✅ Training and evaluation complete.")
930
-
931
-
932
- def _check_paths(datasets_dir: str, save_dir:str):
933
- if not os.path.isdir(save_dir):
934
- os.makedirs(save_dir)
935
- if not os.path.isdir(datasets_dir):
936
- raise IOError(f"Datasets directory '{datasets_dir}' not found.")
937
935
 
938
936
 
939
937
  def info():