dragon-ml-toolbox 20.7.1__py3-none-any.whl → 20.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.8.0.dist-info}/METADATA +3 -1
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.8.0.dist-info}/RECORD +16 -12
- ml_tools/ML_configuration/_metrics.py +17 -10
- ml_tools/ML_evaluation/_classification.py +79 -7
- ml_tools/data_exploration/__init__.py +5 -1
- ml_tools/data_exploration/_analysis.py +149 -0
- ml_tools/data_exploration/_features.py +76 -0
- ml_tools/keys/_keys.py +1 -0
- ml_tools/resampling/__init__.py +19 -0
- ml_tools/resampling/_base_resampler.py +49 -0
- ml_tools/resampling/_multi_resampling.py +184 -0
- ml_tools/resampling/_single_resampling.py +113 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.8.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.8.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.8.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.8.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dragon-ml-toolbox
|
|
3
|
-
Version: 20.7.1
|
|
3
|
+
Version: 20.8.0
|
|
4
4
|
Summary: Complete pipelines and helper tools for data science and machine learning projects.
|
|
5
5
|
Author-email: Karl Luigi Loza Vidaurre <luigiloza@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -174,6 +174,7 @@ ML_vision_transformers
|
|
|
174
174
|
optimization_tools
|
|
175
175
|
path_manager
|
|
176
176
|
plot_fonts
|
|
177
|
+
resampling
|
|
177
178
|
schema
|
|
178
179
|
serde
|
|
179
180
|
SQL
|
|
@@ -206,6 +207,7 @@ optimization_tools
|
|
|
206
207
|
path_manager
|
|
207
208
|
plot_fonts
|
|
208
209
|
PSO_optimization
|
|
210
|
+
resampling
|
|
209
211
|
schema
|
|
210
212
|
serde
|
|
211
213
|
SQL
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
dragon_ml_toolbox-20.
|
|
2
|
-
dragon_ml_toolbox-20.
|
|
1
|
+
dragon_ml_toolbox-20.8.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
|
|
2
|
+
dragon_ml_toolbox-20.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
|
|
3
3
|
ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
|
|
5
5
|
ml_tools/ETL_cleaning/__init__.py,sha256=gLRHF-qzwpqKTvbbn9chIQELeUDh_XGpBRX28j-5IqI,545
|
|
@@ -30,7 +30,7 @@ ml_tools/ML_chain/_update_schema.py,sha256=z1Us7lv6hy6GwSu1mcid50Jmqq3sh91hMQ0Ln
|
|
|
30
30
|
ml_tools/ML_configuration/__init__.py,sha256=ogktFnYxz5jWJkhHS4DVaMldHkt3lT2gw9jx5PQ3d78,2755
|
|
31
31
|
ml_tools/ML_configuration/_base_model_config.py,sha256=95L3IfobNFMtnNr79zYpDGerC1q1v7M05tWZvTS2cwE,2247
|
|
32
32
|
ml_tools/ML_configuration/_finalize.py,sha256=l_n13bLu0avMdJ8hNRrH8V_wOBQZM1UGsTydKBkTysM,15047
|
|
33
|
-
ml_tools/ML_configuration/_metrics.py,sha256=
|
|
33
|
+
ml_tools/ML_configuration/_metrics.py,sha256=KJM7HQeoEmJUUUrxNa4wYf2N9NawGPJoy7AGdNO3gxQ,24059
|
|
34
34
|
ml_tools/ML_configuration/_models.py,sha256=lvuuqvD6DWUzOa3i06NZfrdfOi9bu2e26T_QO6BGMSw,7629
|
|
35
35
|
ml_tools/ML_configuration/_training.py,sha256=_M_TwouHFNbGrZQtQNAvyG_poSVpmN99cbyUonZsHhk,8969
|
|
36
36
|
ml_tools/ML_datasetmaster/__init__.py,sha256=UltQzuXnlXVCkD-aeA5TW4IcMVLnQf1_aglawg4WyrI,580
|
|
@@ -39,7 +39,7 @@ ml_tools/ML_datasetmaster/_datasetmaster.py,sha256=Oy2UE3YJpKTaFwQF5TkQLgLB54-BF
|
|
|
39
39
|
ml_tools/ML_datasetmaster/_sequence_datasetmaster.py,sha256=cW3fuILZWs-7Yuo4T2fgGfTC4vwho3Gp4ohIKJYS7O0,18452
|
|
40
40
|
ml_tools/ML_datasetmaster/_vision_datasetmaster.py,sha256=kvSqXYeNBN1JSRfSEEXYeIcsqy9HsJAl_EwFWClqlsw,67025
|
|
41
41
|
ml_tools/ML_evaluation/__init__.py,sha256=e3c8JNP0tt4Kxc7QSQpGcOgrxf8JAucH4UkJvJxUL2E,1122
|
|
42
|
-
ml_tools/ML_evaluation/_classification.py,sha256=
|
|
42
|
+
ml_tools/ML_evaluation/_classification.py,sha256=Te5ckLfBCUyb3QO9vZ_mlJF5wS5LoajXC54k1Fkct-U,33938
|
|
43
43
|
ml_tools/ML_evaluation/_feature_importance.py,sha256=mTwi3LKom_axu6UFKunELj30APDdhG9GQC2w7I9mYhI,17137
|
|
44
44
|
ml_tools/ML_evaluation/_loss.py,sha256=1a4O25i3Ya_3naNZNL7ELLUL46BY86g1scA7d7q2UFM,3625
|
|
45
45
|
ml_tools/ML_evaluation/_regression.py,sha256=hnT2B2_6AnQ7aA7uk-X2lZL9G5JFGCduDXyZbr1gFCA,11037
|
|
@@ -103,10 +103,10 @@ ml_tools/_core/__init__.py,sha256=m-VP0RW0tOTm9N5NI3kFNcpM7WtVgs0RK9pK3ZJRZQQ,14
|
|
|
103
103
|
ml_tools/_core/_logger.py,sha256=xzhn_FouMDRVNwXGBGlPC9Ruq6i5uCrmNaS5jesguMU,4972
|
|
104
104
|
ml_tools/_core/_schema_load_ops.py,sha256=KLs9vBzANz5ESe2wlP-C41N4VlgGil-ywcfvWKSOGss,1551
|
|
105
105
|
ml_tools/_core/_script_info.py,sha256=LtFGt10gEvCnhIRMKJPi2yXkiGLcdr7lE-oIP2XGHzQ,234
|
|
106
|
-
ml_tools/data_exploration/__init__.py,sha256=
|
|
107
|
-
ml_tools/data_exploration/_analysis.py,sha256=
|
|
106
|
+
ml_tools/data_exploration/__init__.py,sha256=efUBsruHL56B429tUadl3PdG73zAF639Y430uMQRfko,1917
|
|
107
|
+
ml_tools/data_exploration/_analysis.py,sha256=PJNrEBz5ZZXHoUlQ6fh9Y86nzPQrLpVPv2Ye4NfOxgs,14181
|
|
108
108
|
ml_tools/data_exploration/_cleaning.py,sha256=pAZOXgGK35j7O8q6cnyTwYK1GLNnD04A8p2fSyMB1mg,20906
|
|
109
|
-
ml_tools/data_exploration/_features.py,sha256=
|
|
109
|
+
ml_tools/data_exploration/_features.py,sha256=Z1noJfDxBzFRfusFp6NlpLF2NItuZuzFHq4ssWFqny4,26273
|
|
110
110
|
ml_tools/data_exploration/_plotting.py,sha256=zH1dPcIoAlOuww23xIoBCsQOAshPPv9OyGposOA2RvI,19883
|
|
111
111
|
ml_tools/data_exploration/_schema_ops.py,sha256=Fd6fBGGv4OpxmJ1HG9pith6QL90z0tzssCvzkQxlEEQ,11083
|
|
112
112
|
ml_tools/ensemble_evaluation/__init__.py,sha256=t4Gr8EGEk8RLatyc92-S0BzbQvdvodzoF-qDAH2qjVg,546
|
|
@@ -118,7 +118,7 @@ ml_tools/ensemble_learning/_ensemble_learning.py,sha256=MHDZBR20_nStlSSeThFI3bSu
|
|
|
118
118
|
ml_tools/excel_handler/__init__.py,sha256=AaWM3n_dqBhJLTs3OEA57ex5YykKXNOwVCyHlVsdnqI,530
|
|
119
119
|
ml_tools/excel_handler/_excel_handler.py,sha256=TODudmeQgDSdxUKzLfAzizs--VL-g8WxDOfQ4sgxxLs,13965
|
|
120
120
|
ml_tools/keys/__init__.py,sha256=-0c2pmrhyfROc-oQpEjJGLBMhSagA3CyFijQaaqZRqU,399
|
|
121
|
-
ml_tools/keys/_keys.py,sha256=
|
|
121
|
+
ml_tools/keys/_keys.py,sha256=jBhw99SRTlBkb9EFMDLZA86_kaHT4YLxkljDYRCTarE,9389
|
|
122
122
|
ml_tools/math_utilities/__init__.py,sha256=K7Obkkc4rPKj4EbRZf1BsXHfiCg7FXYv_aN9Yc2Z_Vg,400
|
|
123
123
|
ml_tools/math_utilities/_math_utilities.py,sha256=BYHIVcM9tuKIhVrkgLLiM5QalJ39zx7dXYy_M9aGgiM,9012
|
|
124
124
|
ml_tools/optimization_tools/__init__.py,sha256=KD8JXpfGuPndO4AHnjJGu6uV1GRwhOfboD0KZV45kzw,658
|
|
@@ -129,6 +129,10 @@ ml_tools/path_manager/_dragonmanager.py,sha256=q9wHTKPmdzywEz6N14ipUoeR3MmW0bzB4
|
|
|
129
129
|
ml_tools/path_manager/_path_tools.py,sha256=LcZE31QlkzZWUR8g1MW_N_mPY2DpKBJLA45VJz7ZYsw,11905
|
|
130
130
|
ml_tools/plot_fonts/__init__.py,sha256=KIxXRCjQ3SliEoLhEcqs7zDVZbVTn38bmSdL-yR1Q2w,187
|
|
131
131
|
ml_tools/plot_fonts/_plot_fonts.py,sha256=mfjXNT9P59ymHoTI85Q8CcvfxfK5BIFBWtTZH-hNIC4,2209
|
|
132
|
+
ml_tools/resampling/__init__.py,sha256=WB1YlNQgOIdSSQn-9eCIaiB0AHLSCkziFufqa-1QBG0,278
|
|
133
|
+
ml_tools/resampling/_base_resampler.py,sha256=8IqkEJ7uiAiC9bqbKfsC-5vIvrN3EwH7lLVDlRKQzM8,1617
|
|
134
|
+
ml_tools/resampling/_multi_resampling.py,sha256=m_iVvXPAu3p_EoBt2VZpgjhPLY1LmKa8fGtQo5E0pWk,7199
|
|
135
|
+
ml_tools/resampling/_single_resampling.py,sha256=zKL4Br7Lm4Jq90X-ewQ6AKTsP923bq9RIMnTxIxtXBc,3896
|
|
132
136
|
ml_tools/schema/__init__.py,sha256=K6uiZ9f0GCQ7etw1yl2-dQVLhU7RkL3KHesO3HNX6v4,334
|
|
133
137
|
ml_tools/schema/_feature_schema.py,sha256=MuPf6Nf7tDhUTGyX7tcFHZh-lLSNsJkLmlf9IxdF4O4,9660
|
|
134
138
|
ml_tools/schema/_gui_schema.py,sha256=IVwN4THAdFrvh2TpV4SFd_zlzMX3eioF-w-qcSVTndE,7245
|
|
@@ -138,7 +142,7 @@ ml_tools/utilities/__init__.py,sha256=h4lE3SQstg-opcQj6QSKhu-HkqSbmHExsWoM9vC5D9
|
|
|
138
142
|
ml_tools/utilities/_translate.py,sha256=U8hRPa3PmTpIf9n9yR3gBGmp_hkcsjQLwjAHSHc0WHs,10325
|
|
139
143
|
ml_tools/utilities/_utility_save_load.py,sha256=EFvFaTaHahDQWdJWZr-j7cHqRbG_Xrpc96228JhV-bs,16773
|
|
140
144
|
ml_tools/utilities/_utility_tools.py,sha256=bN0J9d1S0W5wNzNntBWqDsJcEAK7-1OgQg3X2fwXns0,6918
|
|
141
|
-
dragon_ml_toolbox-20.
|
|
142
|
-
dragon_ml_toolbox-20.
|
|
143
|
-
dragon_ml_toolbox-20.
|
|
144
|
-
dragon_ml_toolbox-20.
|
|
145
|
+
dragon_ml_toolbox-20.8.0.dist-info/METADATA,sha256=EVzUhpCzHarcTicuqc_t4prSuJdXGuCppSX7wnIv1JY,7888
|
|
146
|
+
dragon_ml_toolbox-20.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
147
|
+
dragon_ml_toolbox-20.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
|
|
148
|
+
dragon_ml_toolbox-20.8.0.dist-info/RECORD,,
|
|
@@ -98,10 +98,11 @@ class _BaseMultiLabelFormat:
|
|
|
98
98
|
cmap: str = "BuGn",
|
|
99
99
|
ROC_PR_line: str='darkorange',
|
|
100
100
|
calibration_bins: Union[int, Literal['auto']]='auto',
|
|
101
|
-
font_size: int =
|
|
102
|
-
xtick_size: int=
|
|
103
|
-
ytick_size: int=
|
|
104
|
-
legend_size: int=
|
|
101
|
+
font_size: int = 26,
|
|
102
|
+
xtick_size: int=22,
|
|
103
|
+
ytick_size: int=22,
|
|
104
|
+
legend_size: int=26,
|
|
105
|
+
cm_font_size: int=26) -> None:
|
|
105
106
|
"""
|
|
106
107
|
Initializes the formatting configuration for multi-label classification metrics.
|
|
107
108
|
|
|
@@ -127,6 +128,8 @@ class _BaseMultiLabelFormat:
|
|
|
127
128
|
|
|
128
129
|
legend_size (int): Font size for plot legends.
|
|
129
130
|
|
|
131
|
+
cm_font_size (int): Font size for the confusion matrix.
|
|
132
|
+
|
|
130
133
|
<br>
|
|
131
134
|
|
|
132
135
|
### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
|
|
@@ -142,6 +145,7 @@ class _BaseMultiLabelFormat:
|
|
|
142
145
|
self.xtick_size = xtick_size
|
|
143
146
|
self.ytick_size = ytick_size
|
|
144
147
|
self.legend_size = legend_size
|
|
148
|
+
self.cm_font_size = cm_font_size
|
|
145
149
|
|
|
146
150
|
def __repr__(self) -> str:
|
|
147
151
|
parts = [
|
|
@@ -151,7 +155,8 @@ class _BaseMultiLabelFormat:
|
|
|
151
155
|
f"font_size={self.font_size}",
|
|
152
156
|
f"xtick_size={self.xtick_size}",
|
|
153
157
|
f"ytick_size={self.ytick_size}",
|
|
154
|
-
f"legend_size={self.legend_size}"
|
|
158
|
+
f"legend_size={self.legend_size}",
|
|
159
|
+
f"cm_font_size={self.cm_font_size}"
|
|
155
160
|
]
|
|
156
161
|
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
157
162
|
|
|
@@ -520,10 +525,11 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
|
|
|
520
525
|
cmap: str = "BuGn",
|
|
521
526
|
ROC_PR_line: str='darkorange',
|
|
522
527
|
calibration_bins: Union[int, Literal['auto']]='auto',
|
|
523
|
-
font_size: int =
|
|
524
|
-
xtick_size: int=
|
|
525
|
-
ytick_size: int=
|
|
526
|
-
legend_size: int=
|
|
528
|
+
font_size: int = 26,
|
|
529
|
+
xtick_size: int=22,
|
|
530
|
+
ytick_size: int=22,
|
|
531
|
+
legend_size: int=26,
|
|
532
|
+
cm_font_size: int=26
|
|
527
533
|
) -> None:
|
|
528
534
|
super().__init__(cmap=cmap,
|
|
529
535
|
ROC_PR_line=ROC_PR_line,
|
|
@@ -531,7 +537,8 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
|
|
|
531
537
|
font_size=font_size,
|
|
532
538
|
xtick_size=xtick_size,
|
|
533
539
|
ytick_size=ytick_size,
|
|
534
|
-
legend_size=legend_size
|
|
540
|
+
legend_size=legend_size,
|
|
541
|
+
cm_font_size=cm_font_size)
|
|
535
542
|
|
|
536
543
|
|
|
537
544
|
# Segmentation
|
|
@@ -481,6 +481,10 @@ def multi_label_classification_metrics(
|
|
|
481
481
|
ytick_size = format_config.ytick_size
|
|
482
482
|
legend_size = format_config.legend_size
|
|
483
483
|
base_font_size = format_config.font_size
|
|
484
|
+
|
|
485
|
+
# config font size for heatmap
|
|
486
|
+
cm_font_size = format_config.cm_font_size
|
|
487
|
+
cm_tick_size = cm_font_size - 4
|
|
484
488
|
|
|
485
489
|
# --- Calculate and Save Overall Metrics (using y_pred) ---
|
|
486
490
|
h_loss = hamming_loss(y_true, y_pred)
|
|
@@ -488,7 +492,7 @@ def multi_label_classification_metrics(
|
|
|
488
492
|
j_score_macro = jaccard_score(y_true, y_pred, average='macro')
|
|
489
493
|
|
|
490
494
|
overall_report = (
|
|
491
|
-
f"Overall Multi-Label Metrics:\n"
|
|
495
|
+
f"Overall Multi-Label Metrics:\n"
|
|
492
496
|
f"--------------------------------------------------\n"
|
|
493
497
|
f"Hamming Loss: {h_loss:.4f}\n"
|
|
494
498
|
f"Jaccard Score (micro): {j_score_micro:.4f}\n"
|
|
@@ -498,14 +502,82 @@ def multi_label_classification_metrics(
|
|
|
498
502
|
# print(overall_report)
|
|
499
503
|
overall_report_path = save_dir_path / "classification_report.txt"
|
|
500
504
|
overall_report_path.write_text(overall_report)
|
|
505
|
+
|
|
506
|
+
# --- Save Classification Report Heatmap (Multi-label) ---
|
|
507
|
+
try:
|
|
508
|
+
# Generate full report as dict
|
|
509
|
+
full_report_dict = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
|
|
510
|
+
report_df = pd.DataFrame(full_report_dict)
|
|
511
|
+
|
|
512
|
+
# Cleanup
|
|
513
|
+
# Remove 'accuracy' column if it exists
|
|
514
|
+
report_df = report_df.drop(columns=['accuracy'], errors='ignore')
|
|
515
|
+
|
|
516
|
+
# Remove 'support' row explicitly
|
|
517
|
+
if 'support' in report_df.index:
|
|
518
|
+
report_df = report_df.drop(index='support')
|
|
519
|
+
|
|
520
|
+
# Transpose: Rows = Classes/Averages, Cols = Metrics
|
|
521
|
+
plot_df = report_df.T
|
|
522
|
+
|
|
523
|
+
# Dynamic Height
|
|
524
|
+
fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
|
|
525
|
+
fig_width = 8.0
|
|
526
|
+
|
|
527
|
+
fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
|
|
528
|
+
|
|
529
|
+
# Plot
|
|
530
|
+
sns.heatmap(plot_df,
|
|
531
|
+
annot=True,
|
|
532
|
+
cmap=format_config.cmap,
|
|
533
|
+
fmt='.2f',
|
|
534
|
+
vmin=0.0,
|
|
535
|
+
vmax=1.0,
|
|
536
|
+
cbar_kws={'shrink': 0.9})
|
|
537
|
+
|
|
538
|
+
ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
|
|
539
|
+
|
|
540
|
+
# manually increase the font size of the elements
|
|
541
|
+
for text in ax_heat.texts:
|
|
542
|
+
text.set_fontsize(cm_tick_size)
|
|
543
|
+
|
|
544
|
+
cbar = ax_heat.collections[0].colorbar
|
|
545
|
+
cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore
|
|
546
|
+
|
|
547
|
+
ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
|
|
548
|
+
ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0)
|
|
549
|
+
|
|
550
|
+
plt.tight_layout()
|
|
551
|
+
heatmap_path = save_dir_path / "classification_report_heatmap.svg"
|
|
552
|
+
plt.savefig(heatmap_path)
|
|
553
|
+
_LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
|
|
554
|
+
plt.close(fig_heat)
|
|
555
|
+
|
|
556
|
+
except Exception as e:
|
|
557
|
+
_LOGGER.error(f"Could not generate multi-label classification report heatmap: {e}")
|
|
501
558
|
|
|
502
559
|
# --- Per-Label Metrics and Plots ---
|
|
503
560
|
for i, name in enumerate(target_names):
|
|
504
|
-
|
|
561
|
+
# strip whitespace from name
|
|
562
|
+
name = name.strip()
|
|
563
|
+
|
|
564
|
+
# print(f" -> Evaluating label: '{name}'")
|
|
505
565
|
true_i = y_true[:, i]
|
|
506
566
|
pred_i = y_pred[:, i] # Use passed-in y_pred
|
|
507
567
|
prob_i = y_prob[:, i] # Use passed-in y_prob
|
|
508
568
|
sanitized_name = sanitize_filename(name)
|
|
569
|
+
|
|
570
|
+
# if name is too long, just take the first letter of each word. Each word might be separated by space or underscore
|
|
571
|
+
if len(name) >= _EvaluationConfig.NAME_LIMIT:
|
|
572
|
+
parts = [w for w in name.replace("_", " ").split() if w]
|
|
573
|
+
abbr = "".join(p[0].upper() for p in parts)
|
|
574
|
+
# keep only alpha numeric chars
|
|
575
|
+
abbr = "".join(ch for ch in abbr if ch.isalnum())
|
|
576
|
+
if not abbr:
|
|
577
|
+
# fallback to a sanitized, truncated version of the original name
|
|
578
|
+
abbr = sanitize_filename(name)[: _EvaluationConfig.NAME_LIMIT]
|
|
579
|
+
_LOGGER.warning(f"Using abbreviated name '{abbr}' for '{name}' plots.")
|
|
580
|
+
name = abbr
|
|
509
581
|
|
|
510
582
|
# --- Save Classification Report for the label (uses y_pred) ---
|
|
511
583
|
report_text = classification_report(true_i, pred_i)
|
|
@@ -537,7 +609,7 @@ def multi_label_classification_metrics(
|
|
|
537
609
|
ax_cm.tick_params(axis='y', labelsize=ytick_size)
|
|
538
610
|
|
|
539
611
|
# Set titles and labels with padding
|
|
540
|
-
ax_cm.set_title(f"Confusion Matrix
|
|
612
|
+
ax_cm.set_title(f"Confusion Matrix - {name}", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
|
|
541
613
|
ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
542
614
|
ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
543
615
|
|
|
@@ -594,7 +666,7 @@ def multi_label_classification_metrics(
|
|
|
594
666
|
ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
|
|
595
667
|
ax_roc.plot([0, 1], [0, 1], 'k--')
|
|
596
668
|
|
|
597
|
-
ax_roc.set_title(f'ROC Curve
|
|
669
|
+
ax_roc.set_title(f'ROC Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
|
|
598
670
|
ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
599
671
|
ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
600
672
|
|
|
@@ -616,7 +688,7 @@ def multi_label_classification_metrics(
|
|
|
616
688
|
ap_score = average_precision_score(true_i, prob_i)
|
|
617
689
|
fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
|
|
618
690
|
ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
|
|
619
|
-
ax_pr.set_title(f'
|
|
691
|
+
ax_pr.set_title(f'PR Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
|
|
620
692
|
ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
621
693
|
ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
622
694
|
|
|
@@ -659,10 +731,10 @@ def multi_label_classification_metrics(
|
|
|
659
731
|
prob_true,
|
|
660
732
|
marker='o',
|
|
661
733
|
linewidth=2,
|
|
662
|
-
label=f"Calibration
|
|
734
|
+
label=f"Model Calibration",
|
|
663
735
|
color=format_config.ROC_PR_line)
|
|
664
736
|
|
|
665
|
-
ax_cal.set_title(f'
|
|
737
|
+
ax_cal.set_title(f'Calibration - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
|
|
666
738
|
ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
667
739
|
ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
|
|
668
740
|
|
|
@@ -2,6 +2,7 @@ from ._analysis import (
|
|
|
2
2
|
summarize_dataframe,
|
|
3
3
|
show_null_columns,
|
|
4
4
|
match_and_filter_columns_by_regex,
|
|
5
|
+
check_class_balance,
|
|
5
6
|
)
|
|
6
7
|
|
|
7
8
|
from ._cleaning import (
|
|
@@ -28,6 +29,7 @@ from ._features import (
|
|
|
28
29
|
split_continuous_binary,
|
|
29
30
|
split_continuous_categorical_targets,
|
|
30
31
|
encode_categorical_features,
|
|
32
|
+
encode_classification_target,
|
|
31
33
|
reconstruct_one_hot,
|
|
32
34
|
reconstruct_binary,
|
|
33
35
|
reconstruct_multibinary,
|
|
@@ -44,7 +46,6 @@ from .._core import _imprimir_disponibles
|
|
|
44
46
|
|
|
45
47
|
__all__ = [
|
|
46
48
|
"summarize_dataframe",
|
|
47
|
-
"show_null_columns",
|
|
48
49
|
"drop_constant_columns",
|
|
49
50
|
"drop_rows_with_missing_data",
|
|
50
51
|
"drop_columns_with_missing_data",
|
|
@@ -61,10 +62,13 @@ __all__ = [
|
|
|
61
62
|
"plot_categorical_vs_target",
|
|
62
63
|
"plot_correlation_heatmap",
|
|
63
64
|
"encode_categorical_features",
|
|
65
|
+
"encode_classification_target",
|
|
64
66
|
"finalize_feature_schema",
|
|
65
67
|
"apply_feature_schema",
|
|
66
68
|
"reconstruct_from_schema",
|
|
67
69
|
"match_and_filter_columns_by_regex",
|
|
70
|
+
"show_null_columns",
|
|
71
|
+
"check_class_balance",
|
|
68
72
|
"standardize_percentages",
|
|
69
73
|
"reconstruct_one_hot",
|
|
70
74
|
"reconstruct_binary",
|
|
@@ -16,6 +16,7 @@ __all__ = [
|
|
|
16
16
|
"summarize_dataframe",
|
|
17
17
|
"show_null_columns",
|
|
18
18
|
"match_and_filter_columns_by_regex",
|
|
19
|
+
"check_class_balance",
|
|
19
20
|
]
|
|
20
21
|
|
|
21
22
|
|
|
@@ -212,3 +213,151 @@ def match_and_filter_columns_by_regex(
|
|
|
212
213
|
|
|
213
214
|
return filtered_df, matched_columns
|
|
214
215
|
|
|
216
|
+
|
|
217
|
+
def check_class_balance(
    df: pd.DataFrame,
    target: Union[str, list[str]],
    plot_to_dir: Optional[Union[str, Path]] = None,
    plot_filename: str = "Class_Balance"
) -> pd.DataFrame:
    """
    Analyzes the class balance for classification targets.

    Handles two cases:
        1. Single Column (Binary/Multi-class): Calculates the frequency of each unique
           value. Missing values (NaN) are counted as their own class.
        2. List of Columns (Multi-label Binary): Calculates the frequency of positive
           values per column. A value is positive when it converts to the number 1
           (e.g. 1, 1.0, True, "1"); anything else, including missing or non-numeric
           entries, counts as negative.

    A list containing a single column name is treated as case 1.

    Args:
        df (pd.DataFrame): The input DataFrame.
        target (str | list[str]): The target column name (for single/multi-class classification)
            or list of column names (for multi-label-binary classification).
        plot_to_dir (str | Path | None): Directory to save the balance plot (SVG).
            If None (or empty), no plot is produced.
        plot_filename (str): Filename for the plot (without extension).

    Returns:
        pd.DataFrame: Summary table of counts and percentages.
            - Case 1: index "Class", columns "Count" and "Percentage", sorted by class.
            - Case 2: index "Label", columns "Positive_Count" and "Positive_Percentage",
              sorted by "Positive_Percentage" in descending order.

    Raises:
        ValueError: If `df` is empty, `target` is an empty list, or a target column is missing.
        TypeError: If `target` is neither a string nor a list.
    """
    # Early fail for empty DataFrame and handle list of targets with only one item
    if df.empty:
        _LOGGER.error("Input DataFrame is empty.")
        raise ValueError()

    if isinstance(target, list):
        if len(target) == 0:
            _LOGGER.error("Target list is empty.")
            raise ValueError()
        elif len(target) == 1:
            target = target[0]  # Simplify to single column case

    # Case 1: Single Target (Binary or Multi-class)
    if isinstance(target, str):
        if target not in df.columns:
            _LOGGER.error(f"Target column '{target}' not found in DataFrame.")
            raise ValueError()

        # Count once. With dropna=False the counts sum to len(df), so this equals
        # value_counts(normalize=True, dropna=False) * 100.
        counts = df[target].value_counts(dropna=False).sort_index()
        percents = counts / len(df) * 100

        summary = pd.DataFrame({
            'Count': counts,
            'Percentage': percents.round(2)
        })
        summary.index.name = "Class"

        # Plotting (best effort: failures are logged, not raised)
        if plot_to_dir:
            fig = None
            try:
                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")

                fig = plt.figure(figsize=(10, 6))
                # Convert index to str to handle numeric classes cleanly on x-axis
                x_labels = summary.index.astype(str)
                bars = plt.bar(x_labels, summary['Count'], color='lightgreen', edgecolor='black', alpha=0.7)

                plt.title(f"Class Balance: {target}")
                plt.xlabel(target)
                plt.ylabel("Count")
                plt.grid(axis='y', linestyle='--', alpha=0.5)

                # Add percentage labels on top of bars
                for bar, pct in zip(bars, summary['Percentage']):
                    plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height(),
                             f'{pct:.1f}%', ha='center', va='bottom', fontsize=10)

                plt.tight_layout()
                full_filename = sanitize_filename(plot_filename) + ".svg"
                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
                _LOGGER.info(f"Saved class balance plot: '{full_filename}'")
            except Exception as e:
                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
            finally:
                # Close only the figure created here. A bare plt.close() would close
                # whatever figure is current, which may belong to the caller if the
                # failure happened before this figure existed.
                if fig is not None:
                    plt.close(fig)

        return summary

    # Case 2: Multi-label (List of binary columns)
    elif isinstance(target, list):
        missing_cols = [t for t in target if t not in df.columns]
        if missing_cols:
            _LOGGER.error(f"Target columns not found: {missing_cols}")
            raise ValueError()

        total_count = len(df)
        stats = []
        for col in target:
            try:
                # Coerce to numeric so 1, 1.0, True and "1" all qualify; unparsable or
                # missing entries become NaN and never compare equal to 1.
                numeric_series = pd.to_numeric(df[col], errors='coerce')
                # Count exact positives rather than summing the values, so entries such
                # as 2 or -1 cannot inflate or deflate the positive count.
                pos_count = int((numeric_series == 1).sum())
                pct = (pos_count / total_count) * 100
            except Exception:
                _LOGGER.warning(f"Column '{col}' could not be processed as numeric. Assuming 0 positives.")
                pos_count = 0
                pct = 0.0

            stats.append({
                'Label': col,
                'Positive_Count': pos_count,
                'Positive_Percentage': round(pct, 2)
            })

        # Ascending order so the horizontal bar chart shows the most frequent label on top.
        summary = pd.DataFrame(stats).set_index("Label").sort_values("Positive_Percentage", ascending=True)

        # Plotting (best effort: failures are logged, not raised)
        if plot_to_dir:
            fig = None
            try:
                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")

                # Dynamic height for many labels
                height = max(6, len(target) * 0.4)
                fig = plt.figure(figsize=(10, height))

                bars = plt.barh(summary.index, summary['Positive_Percentage'], color='lightgreen', edgecolor='black', alpha=0.7)

                plt.title("Multi-label Binary Class Balance")
                plt.xlabel("Positive Class Percentage (%)")
                plt.xlim(0, 100)
                plt.grid(axis='x', linestyle='--', alpha=0.5)

                # Annotate each bar with its percentage and absolute positive count
                for bar, count in zip(bars, summary['Positive_Count']):
                    width = bar.get_width()
                    plt.text(width + 1, bar.get_y() + bar.get_height() / 2,
                             f'{width:.1f}% (n={count})', ha='left', va='center', fontsize=9)

                plt.tight_layout()
                full_filename = sanitize_filename(plot_filename) + ".svg"
                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
                _LOGGER.info(f"Saved multi-label balance plot: '{full_filename}'")
            except Exception as e:
                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
            finally:
                # See Case 1: never close a figure this function did not create.
                if fig is not None:
                    plt.close(fig)

        return summary.sort_values("Positive_Percentage", ascending=False)

    else:
        _LOGGER.error("Target must be a string or a list of strings.")
        raise TypeError()
|
|
@@ -3,7 +3,10 @@ from pandas.api.types import is_numeric_dtype, is_object_dtype
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
from typing import Any, Optional, Union
|
|
5
5
|
import re
|
|
6
|
+
import json
|
|
7
|
+
from pathlib import Path
|
|
6
8
|
|
|
9
|
+
from ..path_manager import make_fullpath
|
|
7
10
|
from .._core import get_logger
|
|
8
11
|
|
|
9
12
|
|
|
@@ -15,6 +18,7 @@ __all__ = [
|
|
|
15
18
|
"split_continuous_binary",
|
|
16
19
|
"split_continuous_categorical_targets",
|
|
17
20
|
"encode_categorical_features",
|
|
21
|
+
"encode_classification_target",
|
|
18
22
|
"reconstruct_one_hot",
|
|
19
23
|
"reconstruct_binary",
|
|
20
24
|
"reconstruct_multibinary",
|
|
@@ -263,6 +267,78 @@ def encode_categorical_features(
|
|
|
263
267
|
return df_encoded, mappings
|
|
264
268
|
|
|
265
269
|
|
|
270
|
+
def encode_classification_target(
    df: pd.DataFrame,
    target_col: str,
    save_dir: Union[str, Path],
    verbose: int = 2
) -> tuple[pd.DataFrame, dict[str, int]]:
    """
    Encodes a target classification column into integers (0, 1, 2...) and saves the mapping to a JSON file.

    This ensures that the target variable is in the correct numeric format for training
    and provides a persistent artifact (the JSON file) to map predictions back to labels later.

    Class order is deterministic:
        - Numeric/boolean targets are ordered by value, so e.g. [1, 2, 10] maps to 0, 1, 2
          (a plain string sort would give "1" < "10" < "2").
        - All other targets are ordered by their string representation.

    Args:
        df (pd.DataFrame): Input DataFrame. It is not modified.
        target_col (str): Name of the target column to encode.
        save_dir (str | Path): Directory where the class map JSON ("class_map.json") will be saved.
            Created if it does not exist.
        verbose (int): Verbosity level for logging (>=2 logs the save path, >=3 also prints a preview).

    Returns:
        Tuple (Dataframe, Dict):
        - A new DataFrame with the target column encoded as integers.
        - The dictionary mapping original labels (str) to integers (int).

    Raises:
        ValueError: If `target_col` is missing or contains missing values.
        IOError: If the class map JSON cannot be written.
    """
    if target_col not in df.columns:
        _LOGGER.error(f"Target column '{target_col}' not found in DataFrame.")
        raise ValueError()

    target_series = df[target_col]

    # Validation: Check for missing values in target
    n_missing = int(target_series.isnull().sum())
    if n_missing > 0:
        _LOGGER.error(f"Target column '{target_col}' contains {n_missing} missing values. Please handle them before encoding.")
        raise ValueError()

    # Ensure directory exists
    save_path = make_fullpath(save_dir, make=True, enforce="directory")
    file_path = save_path / "class_map.json"

    # Keys are the string form of each label so they round-trip through JSON unchanged.
    if is_numeric_dtype(target_series):
        # Order by value, then stringify with the same astype(str) used when mapping below.
        ordered_values = target_series.drop_duplicates().sort_values()
        unique_labels = list(dict.fromkeys(ordered_values.astype(str)))
    else:
        unique_labels = sorted(target_series.astype(str).unique())

    # Create mapping: { Label -> Integer }
    class_map = {label: idx for idx, label in enumerate(unique_labels)}

    # Apply mapping; cast column to string to match the keys in class_map
    df_encoded = df.copy()
    df_encoded[target_col] = df_encoded[target_col].astype(str).map(class_map)

    # Save to JSON (only the write itself can fail with an I/O error)
    try:
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(class_map, f, indent=4)
    except OSError as e:
        _LOGGER.error(f"Failed to save class map JSON. Error: {e}")
        raise IOError() from e

    if verbose >= 2:
        _LOGGER.info(f"Class mapping saved to: '{file_path}'")

    if verbose >= 3:
        _LOGGER.info(f"Target '{target_col}' encoded with {len(class_map)} classes.")
        # Print a preview
        if len(class_map) <= 10:
            print(f"    Mapping: {class_map}")
        else:
            print(f"    Mapping (first 5): {dict(list(class_map.items())[:5])} ...")

    return df_encoded, class_map
|
|
340
|
+
|
|
341
|
+
|
|
266
342
|
def reconstruct_one_hot(
|
|
267
343
|
df: pd.DataFrame,
|
|
268
344
|
features_to_reconstruct: list[Union[str, tuple[str, Optional[str]]]],
|
ml_tools/keys/_keys.py
CHANGED
|
@@ -306,6 +306,7 @@ class _EvaluationConfig:
|
|
|
306
306
|
LOSS_PLOT_LEGEND_SIZE = 24
|
|
307
307
|
# CM settings
|
|
308
308
|
CM_SIZE = (9, 8) # used for multi label binary classification confusion matrix
|
|
309
|
+
NAME_LIMIT = 20 # max number of characters for feature/label names in plots
|
|
309
310
|
|
|
310
311
|
class _OneHotOtherPlaceholder:
|
|
311
312
|
"""Used internally by GUI_tools."""
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from ._single_resampling import (
|
|
2
|
+
DragonResampler,
|
|
3
|
+
)
|
|
4
|
+
|
|
5
|
+
from ._multi_resampling import (
|
|
6
|
+
DragonMultiResampler,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
from .._core import _imprimir_disponibles
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
__all__ = ["DragonResampler", "DragonMultiResampler"]


def info():
    """Report the public names exported by the ``resampling`` package.

    Delegates to the project helper ``_imprimir_disponibles`` with this
    module's ``__all__``.
    """
    _imprimir_disponibles(__all__)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from typing import Union
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
__all__ = ["_DragonBaseResampler"]


class _DragonBaseResampler(ABC):
    """
    Base class for Dragon resamplers handling common I/O and state.

    Subclasses operate on Polars DataFrames internally. This base class
    converts pandas input to Polars, optionally shuffles the final result
    with the stored seed, and converts back to pandas when requested.
    """
    def __init__(self,
                 return_pandas: bool = False,
                 seed: int = 42):
        """
        Args:
            return_pandas (bool): If True, results are returned as pandas
                DataFrames; otherwise as Polars DataFrames.
            seed (int): Random seed used for sampling and shuffling.
        """
        self.return_pandas = return_pandas
        self.seed = seed

    def _convert_to_polars(self, df: Union[pd.DataFrame, pl.DataFrame]) -> pl.DataFrame:
        """
        Standardizes input to a Polars DataFrame.

        Args:
            df (pd.DataFrame | pl.DataFrame): Input data.

        Returns:
            pl.DataFrame: ``df`` itself if it is already Polars, otherwise a
            Polars copy of the pandas DataFrame.

        Raises:
            TypeError: If ``df`` is neither a pandas nor a Polars DataFrame.
        """
        if isinstance(df, pl.DataFrame):
            return df
        if isinstance(df, pd.DataFrame):
            return pl.from_pandas(df)
        raise TypeError(
            f"Expected a pandas or polars DataFrame, got {type(df).__name__}."
        )

    def _convert_to_pandas(self, df: pl.DataFrame) -> pd.DataFrame:
        """Converts a Polars DataFrame back to pandas (NumPy-backed columns)."""
        return df.to_pandas(use_pyarrow_extension_array=False)

    def _process_return(self, df: pl.DataFrame, shuffle: bool = True) -> Union[pd.DataFrame, pl.DataFrame]:
        """
        Finalizes the DataFrame:
            1. Global Shuffle (optional but recommended for ML).
            2. Conversion to Pandas (if requested).

        Args:
            df (pl.DataFrame): Resampled data.
            shuffle (bool): Whether to randomly permute the rows using ``self.seed``.

        Returns:
            pd.DataFrame | pl.DataFrame: The finalized data.
        """
        if shuffle:
            # `shuffle=True` is required: Polars' `sample` defaults to
            # shuffle=False, in which case the order of the returned rows is
            # not randomized (sampling every row then keeps the original order).
            df = df.sample(fraction=1.0, with_replacement=False, shuffle=True, seed=self.seed)

        if self.return_pandas:
            return self._convert_to_pandas(df)
        return df

    @abstractmethod
    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
        """
        Prints a statistical summary of the target distribution.
        """
        pass
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import Union, Optional
|
|
5
|
+
|
|
6
|
+
from .._core import get_logger
|
|
7
|
+
|
|
8
|
+
from ._base_resampler import _DragonBaseResampler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
_LOGGER = get_logger("DragonMultiResampler")


__all__ = [
    "DragonMultiResampler",
]


class DragonMultiResampler(_DragonBaseResampler):
    """
    A robust resampler for multi-label binary classification tasks using Polars.

    It provides methods to downsample "all-negative" rows and balance the dataset
    based on unique label combinations (Powerset).
    """
    def __init__(self,
                 target_columns: list[str],
                 return_pandas: bool = False,
                 seed: int = 42):
        """
        Args:
            target_columns (List[str]): The list of binary target column names.
                A single string is accepted and treated as a one-element list.
            return_pandas (bool): Whether to return results as pandas DataFrame.
            seed (int): Random seed for reproducibility.

        Raises:
            ValueError: If ``target_columns`` is empty.
        """
        super().__init__(return_pandas=return_pandas, seed=seed)
        if isinstance(target_columns, str):
            target_columns = [target_columns]
        if not target_columns:
            _LOGGER.error("'target_columns' must contain at least one column name.")
            raise ValueError()
        # Copy so that later mutation of the caller's list cannot affect this resampler.
        self.targets = list(target_columns)

    def downsample_all_negatives(self,
                                 df: Union[pd.DataFrame, pl.DataFrame],
                                 negative_ratio: float = 1.0,
                                 verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
        """
        Downsamples rows where ALL target columns are 0 ("background" class).

        Rows whose target columns sum to a value > 0 are all kept. The rows whose
        targets sum to 0 are randomly reduced to ``int(n_positive_rows * negative_ratio)``
        if there are more than that.

        Args:
            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
            negative_ratio (float): Ratio of negatives to positives to retain. Must be >= 0.
            verbose (int): Verbosity level for logging.

        Returns:
            Dataframe: Resampled DataFrame.

        Raises:
            ValueError: If ``negative_ratio`` is negative (or NaN).
        """
        # `not x >= 0` also rejects NaN.
        if not negative_ratio >= 0:
            _LOGGER.error(f"'negative_ratio' must be non-negative, got {negative_ratio}.")
            raise ValueError()

        df_pl = self._convert_to_polars(df)

        # 1. Identify "All Negative" vs "Has Signal"
        fold_expr = pl.sum_horizontal(pl.col(self.targets)).cast(pl.UInt32)

        df_pos = df_pl.filter(fold_expr > 0)
        df_neg = df_pl.filter(fold_expr == 0)

        n_pos = df_pos.height
        n_neg_original = df_neg.height

        if n_pos == 0:
            if verbose >= 1:
                _LOGGER.warning("No positive cases found in any label. Returning original DataFrame.")
            return self._process_return(df_pl, shuffle=False)

        # 2. Calculate target count for negatives
        target_n_neg = int(n_pos * negative_ratio)

        # 3. Sample if necessary
        if n_neg_original > target_n_neg:
            if verbose >= 2:
                _LOGGER.info(f"📉 Downsampling 'All-Negative' rows from {n_neg_original} to {target_n_neg}")

            # Here we use standard sampling because we are not grouping
            df_neg_sampled = df_neg.sample(n=target_n_neg, seed=self.seed, with_replacement=False)
            df_resampled = pl.concat([df_pos, df_neg_sampled])

            return self._process_return(df_resampled)
        else:
            if verbose >= 1:
                _LOGGER.warning(f"Negative count ({n_neg_original}) is already below target ({target_n_neg}). No downsampling applied.")
            return self._process_return(df_pl, shuffle=False)

    def balance_powerset(self,
                         df: Union[pd.DataFrame, pl.DataFrame],
                         max_samples_per_combination: Optional[int] = None,
                         quantile_limit: float = 0.90,
                         verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
        """
        Groups data by unique label combinations (Powerset) and downsamples
        majority combinations.

        Each distinct combination of target values keeps at most ``cap`` rows,
        chosen at random with ``seed``. Combinations with fewer rows are kept whole.

        Args:
            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
            max_samples_per_combination (int | None): Fixed cap per combination (must be >= 1).
                If None, uses quantile_limit to determine cap.
            quantile_limit (float): Quantile (in [0, 1]) of the combination counts used as the cap
                if max_samples_per_combination is None.
            verbose (int): Verbosity level for logging.

        Returns:
            Dataframe: Resampled DataFrame.

        Raises:
            ValueError: If ``max_samples_per_combination`` < 1, or if it is None and
                ``quantile_limit`` is outside [0, 1].
        """
        if max_samples_per_combination is not None and max_samples_per_combination < 1:
            _LOGGER.error(f"'max_samples_per_combination' must be >= 1, got {max_samples_per_combination}.")
            raise ValueError()
        # `not (0 <= q <= 1)` also rejects NaN.
        if max_samples_per_combination is None and not 0.0 <= quantile_limit <= 1.0:
            _LOGGER.error(f"'quantile_limit' must be within [0, 1], got {quantile_limit}.")
            raise ValueError()

        df_pl = self._convert_to_polars(df)

        # Determine the Cap
        if max_samples_per_combination is None:
            # 1. Frequency of each unique label combination. Grouping directly on the
            #    target columns avoids a temporary helper column that could clash
            #    with (and then drop) a user column of the same name.
            combo_counts = df_pl.group_by(self.targets).len()

            # quantile() returns None for empty data
            q_val = combo_counts["len"].quantile(quantile_limit)

            if q_val is None:
                if verbose >= 1:
                    _LOGGER.warning("Data empty or insufficient to calculate quantile. Returning original.")
                return self._process_return(df_pl, shuffle=False)

            cap_size = int(q_val)

            if verbose >= 3:
                _LOGGER.info(f"📊 Auto-calculated Powerset Cap: {cap_size} samples (based on {quantile_limit} quantile).")
        else:
            cap_size = max_samples_per_combination

        # 2. Apply Stratified Sampling / Capping (Randomized):
        #    a shuffled row index per combination, keep rows whose index < cap.
        df_balanced = (
            df_pl.lazy()
            .filter(
                pl.int_range(0, pl.len())
                .shuffle(seed=self.seed)
                .over(self.targets)
                < cap_size
            )
            .collect()
        )

        if verbose >= 2:
            original_count = df_pl.height
            new_count = df_balanced.height
            _LOGGER.info(f"⚖️ Powerset Balancing: Reduced from {original_count} to {new_count} rows.")

        return self._process_return(df_balanced)

    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
        """
        Logs per-label column sums (positive counts for 0/1 targets) and the
        most frequent label combinations.

        Args:
            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
            top_n (int): Number of rows shown in each table.
        """
        df_pl = self._convert_to_polars(df)
        total_rows = df_pl.height

        message_1 = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Multi-Targets: {len(self.targets)} columns"

        # A. Individual Label Counts (one row per label after the transpose)
        sums = df_pl.select([
            pl.sum(col).alias(col) for col in self.targets
        ]).transpose(include_header=True, header_name="Label", column_names=["Count"])

        sums = sums.with_columns(
            (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
        ).sort("Count", descending=True)

        message_1 += "\n🔹 Individual Label Frequencies:"

        # B. Powerset (Combination) Counts
        message_2 = f"🔹 Top {top_n} Label Combinations (Powerset):"

        combo_stats = (
            df_pl.group_by(self.targets)
            .len(name="Count")
            .sort("Count", descending=True)
            .with_columns(
                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
            )
        )

        _LOGGER.info(f"{message_1}\n{sums.head(top_n)}\n{message_2}\n{combo_stats.head(top_n)}")
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import numpy as np
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
from .._core import get_logger
|
|
7
|
+
|
|
8
|
+
from ._base_resampler import _DragonBaseResampler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
_LOGGER = get_logger("DragonResampler")


__all__ = [
    "DragonResampler",
]


class DragonResampler(_DragonBaseResampler):
    """
    A resampler for Single-Target Classification tasks (Binary or Multiclass).

    It balances classes by downsampling majority classes relative to the size of the minority class.
    """
    def __init__(self,
                 target_column: str,
                 return_pandas: bool = False,
                 seed: int = 42):
        """
        Args:
            target_column (str): The name of the single target column.
            return_pandas (bool): Whether to return results as pandas DataFrame.
            seed (int): Random seed for reproducibility.
        """
        super().__init__(return_pandas=return_pandas, seed=seed)
        self.target = target_column

    def balance_classes(self,
                        df: Union[pd.DataFrame, pl.DataFrame],
                        majority_ratio: float = 1.0,
                        verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
        """
        Downsamples all classes to match the minority class size (scaled by a ratio).

        Every class is capped at ``max(1, int(minority_count * majority_ratio))``
        rows; rows kept from a larger class are chosen at random with ``seed``.
        If no class exceeds the cap, the input is returned unchanged.

        Args:
            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
            majority_ratio (float): Multiplier applied to the minority class size
                to obtain the per-class cap. Must be > 0.
            verbose (int): Verbosity level for logging.

        Returns:
            Dataframe: Resampled DataFrame.

        Raises:
            ValueError: If ``majority_ratio`` is not positive (or is NaN), or the
                class statistics cannot be computed.
        """
        # `not x > 0` also rejects NaN.
        if not majority_ratio > 0:
            _LOGGER.error(f"'majority_ratio' must be greater than 0, got {majority_ratio}.")
            raise ValueError()

        df_pl = self._convert_to_polars(df)

        # 1. Calculate Class Counts
        counts = df_pl.group_by(self.target).len().sort("len")

        if counts.height == 0:
            _LOGGER.error("DataFrame is empty or target column missing.")
            return self._process_return(df_pl, shuffle=False)

        # 2. Identify Statistics
        min_val = counts["len"].min()
        max_val = counts["len"].max()

        if min_val is None or max_val is None:
            _LOGGER.error("Failed to calculate class statistics (unexpected None).")
            raise ValueError()

        minority_count: int = min_val  # type: ignore
        majority_count: int = max_val  # type: ignore

        # Calculate the cap. Clamp to 1 so a small ratio cannot silently
        # discard every row of every class.
        cap_size = max(1, int(minority_count * majority_ratio))

        if verbose >= 3:
            _LOGGER.info(f"📊 Class Distribution:\n{counts}")
            _LOGGER.info(f"🎯 Strategy: Cap majorities at {cap_size}")

        # Optimization: If data is already balanced enough
        if majority_count <= cap_size:
            if verbose >= 2:
                _LOGGER.info("Data is already within the requested balance ratio.")
            return self._process_return(df_pl, shuffle=False)

        # 3. Apply Downsampling (Randomized)
        # We generate a random range index per group and filter by it.
        # This ensures we pick a random subset, not the first N rows.
        df_balanced = (
            df_pl.lazy()
            .filter(
                pl.int_range(0, pl.len())
                .shuffle(seed=self.seed)
                .over(self.target)
                < cap_size
            )
            .collect()
        )

        if verbose >= 2:
            reduced_count = df_balanced.height
            _LOGGER.info(f"⚖️ Balancing Complete: {df_pl.height} -> {reduced_count} rows.")

        return self._process_return(df_balanced)

    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
        """
        Logs the row count and percentage of each target class, largest first.

        Args:
            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
            top_n (int): Number of classes shown.
        """
        df_pl = self._convert_to_polars(df)
        total_rows = df_pl.height

        message = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Single Target: '{self.target}'"

        stats = (
            df_pl.group_by(self.target)
            .len(name="Count")
            .sort("Count", descending=True)
            .with_columns(
                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
            )
        )

        _LOGGER.info(f"{message}\n{stats.head(top_n)}")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|