dragon-ml-toolbox 1.3.2__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/METADATA +19 -2
- dragon_ml_toolbox-1.4.1.dist-info/RECORD +19 -0
- ml_tools/MICE_imputation.py +24 -6
- ml_tools/VIF_factor.py +224 -0
- ml_tools/data_exploration.py +74 -286
- ml_tools/datasetmaster.py +13 -1
- ml_tools/ensemble_learning.py +128 -129
- ml_tools/handle_excel.py +32 -9
- ml_tools/logger.py +10 -1
- ml_tools/particle_swarm_optimization.py +71 -34
- ml_tools/pytorch_models.py +13 -1
- ml_tools/trainer.py +10 -30
- ml_tools/utilities.py +122 -14
- ml_tools/vision_helpers.py +14 -1
- dragon_ml_toolbox-1.3.2.dist-info/RECORD +0 -18
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-1.3.2.dist-info → dragon_ml_toolbox-1.4.1.dist-info}/top_level.txt +0 -0
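Judging from the file list, this release is largely a reorganization: `ml_tools/data_exploration.py` loses 286 lines while the new `ml_tools/VIF_factor.py` gains 224, which suggests the multicollinearity helpers moved into their own module. A minimal migration sketch, assuming `compute_vif` and `drop_vif_based` keep the signatures shown in the removed code below and that `VIF_factor` is their new home:

```python
import pandas as pd

# 1.3.2 import path:
# from ml_tools.data_exploration import compute_vif, drop_vif_based
# Assumed 1.4.1 import path, based on the newly added VIF_factor.py:
from ml_tools.VIF_factor import compute_vif, drop_vif_based

df = pd.DataFrame({
    "x1": [1.0, 2.0, 3.0, 4.0, 5.0],
    "x2": [2.1, 3.9, 6.2, 7.8, 10.1],  # nearly collinear with x1
    "x3": [0.5, 1.5, 0.7, 1.9, 1.1],
})

vif_table = compute_vif(df, plot=False)            # feature/VIF table, sorted descending
df_reduced = drop_vif_based(df, vif_table, threshold=10.0)
```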
ml_tools/data_exploration.py
CHANGED
@@ -2,34 +2,30 @@ import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
-from statsmodels.stats.outliers_influence import variance_inflation_factor
-from statsmodels.tools.tools import add_constant
 from IPython import get_ipython
 from IPython.display import clear_output
 import time
-from typing import Union, Literal, Dict, Tuple
+from typing import Union, Literal, Dict, Tuple
 import os
 import sys
 import textwrap
-from ml_tools.utilities import sanitize_filename
-
-
-# Keep track of all available
-__all__ = [
-
-
-
-
-
-
-
-
-
-
-
-
-    "compute_vif",
-    "drop_vif_based"]
+from ml_tools.utilities import sanitize_filename, _script_info
+
+
+# Keep track of all available tools, show using `info()`
+__all__ = [
+    "summarize_dataframe",
+    "drop_rows_with_missing_data",
+    "split_features_targets",
+    "show_null_columns",
+    "drop_columns_with_missing_data",
+    "split_continuous_binary",
+    "plot_correlation_heatmap",
+    "check_value_distributions",
+    "plot_value_distributions",
+    "clip_outliers_single",
+    "clip_outliers_multi"
+]
 
 
 def summarize_dataframe(df: pd.DataFrame, round_digits: int = 2):
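The rewritten `__all__` now enumerates only the tools that remain in this module, which both drives the new `info()` helper (see the end of this file's diff) and constrains star imports. A small illustration of that second effect, assuming the 1.4.1 layout:

```python
# With the new __all__, a star import exposes only the listed tools:
from ml_tools.data_exploration import *

summarize_dataframe   # available: listed in __all__
# compute_vif         # NameError in 1.4.1: no longer part of this module
```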
@@ -63,34 +59,6 @@ def summarize_dataframe(df: pd.DataFrame, round_digits: int = 2):
     return summary
 
 
-def show_null_columns(df: pd.DataFrame, round_digits: int = 2):
-    """
-    Displays a table of columns with missing values, showing both the count and
-    percentage of missing entries per column.
-
-    Parameters:
-        df (pd.DataFrame): The input DataFrame.
-        round_digits (int): Number of decimal places for the percentage.
-
-    Returns:
-        pd.DataFrame: A DataFrame summarizing missing values in each column.
-    """
-    null_counts = df.isnull().sum()
-    null_percent = df.isnull().mean() * 100
-
-    # Filter only columns with at least one null
-    mask = null_counts > 0
-    null_summary = pd.DataFrame({
-        'Missing Count': null_counts[mask],
-        'Missing %': null_percent[mask].round(round_digits)
-    })
-
-    # Sort by descending percentage of missing values
-    null_summary = null_summary.sort_values(by='Missing %', ascending=False)
-    # print(null_summary)
-    return null_summary
-
-
 def drop_rows_with_missing_data(df: pd.DataFrame, threshold: float = 0.7) -> pd.DataFrame:
     """
     Drops rows with more than `threshold` fraction of missing values.
@@ -137,6 +105,57 @@ def split_features_targets(df: pd.DataFrame, targets: list[str]):
     return df_targets, df_features
 
 
+def show_null_columns(df: pd.DataFrame, round_digits: int = 2):
+    """
+    Displays a table of columns with missing values, showing both the count and
+    percentage of missing entries per column.
+
+    Parameters:
+        df (pd.DataFrame): The input DataFrame.
+        round_digits (int): Number of decimal places for the percentage.
+
+    Returns:
+        pd.DataFrame: A DataFrame summarizing missing values in each column.
+    """
+    null_counts = df.isnull().sum()
+    null_percent = df.isnull().mean() * 100
+
+    # Filter only columns with at least one null
+    mask = null_counts > 0
+    null_summary = pd.DataFrame({
+        'Missing Count': null_counts[mask],
+        'Missing %': null_percent[mask].round(round_digits)
+    })
+
+    # Sort by descending percentage of missing values
+    null_summary = null_summary.sort_values(by='Missing %', ascending=False)
+    # print(null_summary)
+    return null_summary
+
+
+def drop_columns_with_missing_data(df: pd.DataFrame, threshold: float = 0.7) -> pd.DataFrame:
+    """
+    Drops columns with more than `threshold` fraction of missing values.
+
+    Parameters:
+        df (pd.DataFrame): The input DataFrame.
+        threshold (float): Fraction of missing values above which columns are dropped.
+
+    Returns:
+        pd.DataFrame: A new DataFrame without the dropped columns.
+    """
+    missing_fraction = df.isnull().mean()
+    cols_to_drop = missing_fraction[missing_fraction > threshold].index
+
+    if len(cols_to_drop) > 0:
+        print(f"Dropping columns with more than {threshold*100:.0f}% missing data:")
+        print(list(cols_to_drop))
+    else:
+        print(f"No columns have more than {threshold*100:.0f}% missing data.")
+
+    return df.drop(columns=cols_to_drop)
+
+
 def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
     """
     Split DataFrame into two DataFrames: one with continuous columns, one with binary columns.
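Both helpers arrive here unchanged from their old position later in the file (removed in the next hunk). A quick usage sketch based on the signatures above:

```python
import pandas as pd
from ml_tools.data_exploration import show_null_columns, drop_columns_with_missing_data

df = pd.DataFrame({"a": [1.0, None, 3.0], "b": [None, None, None]})

# Per-column missing count and percentage, sorted by percentage:
print(show_null_columns(df, round_digits=1))

# "b" is 100% missing (> 70%), so it is dropped; "a" (33%) is kept:
df_clean = drop_columns_with_missing_data(df, threshold=0.7)
```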
@@ -179,29 +198,6 @@ def split_continuous_binary(df: pd.DataFrame) -> Tuple[pd.DataFram
 
     return df_cont, df_bin # type: ignore
 
-
-def drop_columns_with_missing_data(df: pd.DataFrame, threshold: float = 0.7) -> pd.DataFrame:
-    """
-    Drops columns with more than `threshold` fraction of missing values.
-
-    Parameters:
-        df (pd.DataFrame): The input DataFrame.
-        threshold (float): Fraction of missing values above which columns are dropped.
-
-    Returns:
-        pd.DataFrame: A new DataFrame without the dropped columns.
-    """
-    missing_fraction = df.isnull().mean()
-    cols_to_drop = missing_fraction[missing_fraction > threshold].index
-
-    if len(cols_to_drop) > 0:
-        print(f"Dropping columns with more than {threshold*100:.0f}% missing data:")
-        print(list(cols_to_drop))
-    else:
-        print(f"No columns have more than {threshold*100:.0f}% missing data.")
-
-    return df.drop(columns=cols_to_drop)
-
 
 def plot_correlation_heatmap(df: pd.DataFrame, save_dir: Union[str, None] = None, method: Literal["pearson", "kendall", "spearman"]="pearson", plot_title: str="Correlation Heatmap"):
     """
@@ -278,7 +274,7 @@ def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin
     Notes:
         - Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
     """
-    #
+    # cherry-pick columns
     if skip_cols_with_key is not None:
        columns = [col for col in df.columns if skip_cols_with_key not in col]
    else:
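The reworded comment describes a plain substring filter: any column whose name contains `skip_cols_with_key` is excluded from the distribution check. For example:

```python
columns = ["height", "weight", "height_scaled", "weight_scaled"]
skip_cols_with_key = "_scaled"

kept = [col for col in columns if skip_cols_with_key not in col]
assert kept == ["height", "weight"]
```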
@@ -351,7 +347,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
     dict_to_plot_std = dict()
     dict_to_plot_freq = dict()
 
-    #
+    # cherry-pick columns
     if skip_cols_with_key is not None:
         columns = [col for col in df.columns if skip_cols_with_key not in col]
     else:
@@ -399,7 +395,7 @@ def plot_value_distributions(df: pd.DataFrame, save_dir: str, bin_threshold: int
     labels = data.keys()
 
     plt.figure(figsize=(10, 6))
-    colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data)))
+    colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data))) # type: ignore
 
     plt.bar(labels, data.values(), color=colors[:len(data)], alpha=0.85)
     plt.xlabel("Values", fontsize=base_fontsize)
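The touched line keeps the existing color logic: up to 20 bars use the discrete `tab20` palette, beyond that the continuous `viridis` map is sampled evenly; the added `# type: ignore` presumably just silences a static-checker complaint about the mixed tuple-vs-ndarray result types. A standalone sketch of the same pattern:

```python
import numpy as np
import matplotlib.pyplot as plt

data = {f"v{i}": i + 1 for i in range(25)}

# <= 20 categories: fixed qualitative palette; more: sample viridis evenly.
colors = plt.cm.tab20.colors if len(data) <= 20 else plt.cm.viridis(np.linspace(0, 1, len(data)))

plt.bar(list(data.keys()), list(data.values()), color=colors[:len(data)], alpha=0.85)
plt.show()
```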
@@ -518,218 +514,10 @@ def clip_outliers_multi(
     return new_df
 
 
-def merge_dataframes(
-    *dfs: pd.DataFrame,
-    reset_index: bool = False,
-    direction: Literal["horizontal", "vertical"] = "horizontal"
-) -> pd.DataFrame:
-    """
-    Merges multiple DataFrames either horizontally or vertically.
-
-    Parameters:
-        *dfs (pd.DataFrame): Variable number of DataFrames to merge.
-        reset_index (bool): Whether to reset index in the final merged DataFrame.
-        direction (["horizontal" | "vertical"]):
-            - "horizontal": Merge on index, adding columns.
-            - "vertical": Append rows; all DataFrames must have identical columns.
-
-    Returns:
-        pd.DataFrame: A single merged DataFrame.
-
-    Raises:
-        ValueError:
-            - If fewer than 2 DataFrames are provided.
-            - If indexes do not match for horizontal merge.
-            - If column names or order differ for vertical merge.
-    """
-    if len(dfs) < 2:
-        raise ValueError("At least 2 DataFrames must be provided.")
-
-    for i, df in enumerate(dfs, start=1):
-        print(f"DataFrame {i} shape: {df.shape}")
-
-
-    if direction == "horizontal":
-        reference_index = dfs[0].index
-        for i, df in enumerate(dfs, start=1):
-            if not df.index.equals(reference_index):
-                raise ValueError(f"Indexes do not match: Dataset 1 and Dataset {i}.")
-        merged_df = pd.concat(dfs, axis=1)
-
-    elif direction == "vertical":
-        reference_columns = dfs[0].columns
-        for i, df in enumerate(dfs, start=1):
-            if not df.columns.equals(reference_columns):
-                raise ValueError(f"Column names/order do not match: Dataset 1 and Dataset {i}.")
-        merged_df = pd.concat(dfs, axis=0)
-
-    else:
-        raise ValueError(f"Invalid merge direction: {direction}")
-
-    if reset_index:
-        merged_df = merged_df.reset_index(drop=True)
-
-    print(f"Merged DataFrame shape: {merged_df.shape}")
-
-    return merged_df
-
-
-def save_dataframe(df: pd.DataFrame, save_dir: str, filename: str) -> None:
-    """
-    Save a pandas DataFrame to a CSV file.
-
-    Parameters:
-        df: pandas.DataFrame to save
-        save_dir: str, directory where the CSV file will be saved.
-        filename: str, CSV filename, extension will be added if missing.
-    """
-    os.makedirs(save_dir, exist_ok=True)
-
-    filename = sanitize_filename(filename)
-
-    if not filename.endswith('.csv'):
-        filename += '.csv'
-
-    output_path = os.path.join(save_dir, filename)
-
-    df.to_csv(output_path, index=False, encoding='utf-8')
-    print(f"Saved file: '{filename}'")
-
-
-def compute_vif(
-    df: pd.DataFrame,
-    features: Optional[list[str]] = None,
-    ignore_cols: Optional[list[str]] = None,
-    plot: bool = True,
-    save_dir: Union[str, None] = None
-) -> pd.DataFrame:
-    """
-    Computes Variance Inflation Factors (VIF) for numeric features, optionally plots and saves the results.
-
-    There cannot be empty values in the dataset.
-
-    Args:
-        df (pd.DataFrame): The input DataFrame.
-        features (list[str] | None): Optional list of column names to evaluate. Defaults to all numeric columns.
-        ignore_cols (list[str] | None): Optional list of column names to ignore.
-        plot (bool): Whether to display a barplot of VIF values.
-        save_dir (str | None): Directory to save the plot as SVG. If None, plot is not saved.
-
-    Returns:
-        pd.DataFrame: DataFrame with features and corresponding VIF values, sorted descending.
-
-    NOTE:
-        **Variance Inflation Factor (VIF)** quantifies the degree of multicollinearity among features in a dataset.
-        A VIF value indicates how much the variance of a regression coefficient is inflated due to linear dependence with other features.
-        A VIF of 1 suggests no correlation, values between 1 and 5 indicate moderate correlation, and values greater than 10 typically signal high multicollinearity, which may distort model interpretation and degrade performance.
-
-    """
-    if features is None:
-        features = df.select_dtypes(include='number').columns.tolist()
-
-    if ignore_cols is not None:
-        missing = set(ignore_cols) - set(features)
-        if missing:
-            raise ValueError(f"The following 'columns to ignore' are not in the Dataframe:\n{missing}")
-        features = [f for f in features if f not in ignore_cols]
-
-    X = df[features].copy()
-    X = add_constant(X, has_constant='add')
-
-    vif_data = pd.DataFrame()
-    vif_data["feature"] = X.columns
-    vif_data["VIF"] = [variance_inflation_factor(X.values, i) for i in range(X.shape[1])]
-
-    # Drop the constant column
-    vif_data = vif_data[vif_data["feature"] != "const"]
-    vif_data = vif_data.sort_values(by="VIF", ascending=False).reset_index(drop=True) # type: ignore
-
-    # Add color coding based on thresholds
-    def vif_color(v: float) -> str:
-        if v > 10:
-            return "red"
-        elif v > 5:
-            return "gold"
-        else:
-            return "green"
-
-    vif_data["color"] = vif_data["VIF"].apply(vif_color)
-
-    # Plot
-    if plot or save_dir:
-        plt.figure(figsize=(10, 6))
-        bars = plt.barh(
-            vif_data["feature"],
-            vif_data["VIF"],
-            color=vif_data["color"],
-            edgecolor='black'
-        )
-        plt.title("Variance Inflation Factor (VIF) per Feature")
-        plt.xlabel("VIF")
-        plt.axvline(x=5, color='gold', linestyle='--', label='VIF = 5')
-        plt.axvline(x=10, color='red', linestyle='--', label='VIF = 10')
-        plt.legend(loc='lower right')
-        plt.gca().invert_yaxis()
-        plt.grid(axis='x', linestyle='--', alpha=0.5)
-        plt.tight_layout()
-
-        if save_dir:
-            os.makedirs(save_dir, exist_ok=True)
-            save_path = os.path.join(save_dir, "VIF_plot.svg")
-            plt.savefig(save_path, format='svg', bbox_inches='tight')
-            print(f"Saved VIF plot to: {save_path}")
-
-        if plot:
-            plt.show()
-        plt.close()
-
-    return vif_data.drop(columns="color")
-
-
-def drop_vif_based(df: pd.DataFrame, vif_df: pd.DataFrame, threshold: float = 10.0) -> pd.DataFrame:
-    """
-    Drops features from the original DataFrame based on their VIF values exceeding a given threshold.
-
-    Args:
-        df (pd.DataFrame): Original DataFrame containing the features.
-        vif_df (pd.DataFrame): DataFrame with 'feature' and 'VIF' columns as returned by `compute_vif()`.
-        threshold (float): VIF threshold above which features will be dropped.
-
-    Returns:
-        pd.DataFrame: A new DataFrame with high-VIF features removed.
-    """
-    # Ensure expected structure
-    if 'feature' not in vif_df.columns or 'VIF' not in vif_df.columns:
-        raise ValueError("`vif_df` must contain 'feature' and 'VIF' columns.")
-
-    # Identify features to drop
-    to_drop = vif_df[vif_df["VIF"] > threshold]["feature"].tolist()
-    print(f"Dropping {len(to_drop)} feature(s) with VIF > {threshold}: {to_drop}")
-
-    return df.drop(columns=to_drop, errors="ignore")
-
-
 def _is_notebook():
     return get_ipython() is not None
 
 
-def info(full_info: bool = False):
-    """
-    List available functions and their descriptions.
-    """
-    print("Available functions for data exploration:")
-    if full_info:
-        module = sys.modules[__name__]
-        for name in __all__:
-            obj = getattr(module, name, None)
-            if callable(obj):
-                doc = obj.__doc__ or "No docstring provided."
-                formatted_doc = textwrap.indent(textwrap.dedent(doc.strip()), prefix="    ")
-                print(f"\n{name}:\n{formatted_doc}")
-    else:
-        for i, name in enumerate(__all__, start=1):
-            print(f"{i} - {name}")
-
+def info():
+    _script_info(__all__)
 
-if __name__ == "__main__":
-    info()
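The net effect of this hunk: `merge_dataframes` and `save_dataframe` disappear from this module (plausibly absorbed by `ml_tools/utilities.py`, which grows by 122 lines, though their destination is not shown here), the VIF pair moves out as noted above, and the hand-rolled `info(full_info=...)` becomes a one-line delegate to the shared `_script_info` utility, which presumably prints the names collected in `__all__`. Expected usage under that assumption:

```python
from ml_tools import data_exploration

data_exploration.info()                   # 1.4.1: prints the __all__ listing
# data_exploration.info(full_info=True)  # 1.3.2 only: TypeError in 1.4.1
```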
ml_tools/datasetmaster.py
CHANGED
@@ -11,6 +11,15 @@ from PIL import Image
 from torchvision.datasets import ImageFolder
 from torchvision import transforms
 import matplotlib.pyplot as plt
+from .utilities import _script_info
+
+
+__all__ = [
+    "DatasetMaker",
+    "PytorchDataset",
+    "make_vision_dataset",
+    "SequenceDataset",
+]
 
 
 class DatasetMaker():
@@ -592,4 +601,7 @@ class SequenceDataset():
 
     def __len__(self):
         return f"Train: {len(self.train_dataset)}, Test: {len(self.test_dataset)}"
-
+
+
+def info():
+    _script_info(__all__)
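`datasetmaster` adopts the same discovery convention: an explicit `__all__` plus a module-level `info()` deferring to `_script_info`. (Note that `SequenceDataset.__len__` returning an f-string survives unchanged; since Python's `len()` requires an integer, `len(dataset)` on this class would raise a TypeError.) Assuming `_script_info` enumerates `__all__`:

```python
from ml_tools import datasetmaster

datasetmaster.info()
# Expected to list: DatasetMaker, PytorchDataset, make_vision_dataset, SequenceDataset
```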