dragon-ml-toolbox 20.7.1__py3-none-any.whl → 20.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/METADATA +3 -1
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/RECORD +19 -14
- ml_tools/ML_configuration/_metrics.py +17 -10
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +53 -3
- ml_tools/ML_evaluation/_classification.py +83 -8
- ml_tools/ML_evaluation/_helpers.py +41 -0
- ml_tools/ML_evaluation/_regression.py +5 -0
- ml_tools/data_exploration/__init__.py +5 -1
- ml_tools/data_exploration/_analysis.py +149 -0
- ml_tools/data_exploration/_features.py +76 -0
- ml_tools/keys/_keys.py +1 -0
- ml_tools/resampling/__init__.py +19 -0
- ml_tools/resampling/_base_resampler.py +49 -0
- ml_tools/resampling/_multi_resampling.py +184 -0
- ml_tools/resampling/_single_resampling.py +113 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 20.7.1
+Version: 20.9.0
 Summary: Complete pipelines and helper tools for data science and machine learning projects.
 Author-email: Karl Luigi Loza Vidaurre <luigiloza@gmail.com>
 License-Expression: MIT
@@ -174,6 +174,7 @@ ML_vision_transformers
 optimization_tools
 path_manager
 plot_fonts
+resampling
 schema
 serde
 SQL
@@ -206,6 +207,7 @@ optimization_tools
 path_manager
 plot_fonts
 PSO_optimization
+resampling
 schema
 serde
 SQL
```
{dragon_ml_toolbox-20.7.1.dist-info → dragon_ml_toolbox-20.9.0.dist-info}/RECORD
CHANGED
```diff
@@ -1,5 +1,5 @@
-dragon_ml_toolbox-20.
-dragon_ml_toolbox-20.
+dragon_ml_toolbox-20.9.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-20.9.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
 ml_tools/ETL_cleaning/__init__.py,sha256=gLRHF-qzwpqKTvbbn9chIQELeUDh_XGpBRX28j-5IqI,545
@@ -30,19 +30,20 @@ ml_tools/ML_chain/_update_schema.py,sha256=z1Us7lv6hy6GwSu1mcid50Jmqq3sh91hMQ0Ln
 ml_tools/ML_configuration/__init__.py,sha256=ogktFnYxz5jWJkhHS4DVaMldHkt3lT2gw9jx5PQ3d78,2755
 ml_tools/ML_configuration/_base_model_config.py,sha256=95L3IfobNFMtnNr79zYpDGerC1q1v7M05tWZvTS2cwE,2247
 ml_tools/ML_configuration/_finalize.py,sha256=l_n13bLu0avMdJ8hNRrH8V_wOBQZM1UGsTydKBkTysM,15047
-ml_tools/ML_configuration/_metrics.py,sha256=
+ml_tools/ML_configuration/_metrics.py,sha256=KJM7HQeoEmJUUUrxNa4wYf2N9NawGPJoy7AGdNO3gxQ,24059
 ml_tools/ML_configuration/_models.py,sha256=lvuuqvD6DWUzOa3i06NZfrdfOi9bu2e26T_QO6BGMSw,7629
 ml_tools/ML_configuration/_training.py,sha256=_M_TwouHFNbGrZQtQNAvyG_poSVpmN99cbyUonZsHhk,8969
 ml_tools/ML_datasetmaster/__init__.py,sha256=UltQzuXnlXVCkD-aeA5TW4IcMVLnQf1_aglawg4WyrI,580
-ml_tools/ML_datasetmaster/_base_datasetmaster.py,sha256=
+ml_tools/ML_datasetmaster/_base_datasetmaster.py,sha256=IgyVzRY3mlKDyBDklawvPF9SMjZFu8T2red6M-3MlQ4,16074
 ml_tools/ML_datasetmaster/_datasetmaster.py,sha256=Oy2UE3YJpKTaFwQF5TkQLgLB54-BFw_5b8wIPTxZIKU,19157
 ml_tools/ML_datasetmaster/_sequence_datasetmaster.py,sha256=cW3fuILZWs-7Yuo4T2fgGfTC4vwho3Gp4ohIKJYS7O0,18452
 ml_tools/ML_datasetmaster/_vision_datasetmaster.py,sha256=kvSqXYeNBN1JSRfSEEXYeIcsqy9HsJAl_EwFWClqlsw,67025
 ml_tools/ML_evaluation/__init__.py,sha256=e3c8JNP0tt4Kxc7QSQpGcOgrxf8JAucH4UkJvJxUL2E,1122
-ml_tools/ML_evaluation/_classification.py,sha256=
+ml_tools/ML_evaluation/_classification.py,sha256=0URqIhNEgWedy-SYRmIJ2ejLKqatiuOU7qelJ6Cv3OE,33939
 ml_tools/ML_evaluation/_feature_importance.py,sha256=mTwi3LKom_axu6UFKunELj30APDdhG9GQC2w7I9mYhI,17137
+ml_tools/ML_evaluation/_helpers.py,sha256=kE1TSYIOAAcYI1EjdudyTfFeU47Wrl0E9eNL1EOwbKg,1217
 ml_tools/ML_evaluation/_loss.py,sha256=1a4O25i3Ya_3naNZNL7ELLUL46BY86g1scA7d7q2UFM,3625
-ml_tools/ML_evaluation/_regression.py,sha256=
+ml_tools/ML_evaluation/_regression.py,sha256=UZA7_fg85ZKJQWszioWDtmkplSiXeHJk2fBYR5bRXHY,11225
 ml_tools/ML_evaluation/_sequence.py,sha256=gUk9Uvmy7MrXkfrriMnfypkgJU5XERHdqekTa2gBaOM,8004
 ml_tools/ML_evaluation/_vision.py,sha256=abBHQ6Z2GunHNusL3wcLgfI1FVNA6hBUBTq1eOA8FSA,11489
 ml_tools/ML_evaluation_captum/_ML_evaluation_captum.py,sha256=6g3ymSxJGHXxwIN7WCD2Zi9zxKWEv-Qskd2cCGQQJ5Y,18439
@@ -103,10 +104,10 @@ ml_tools/_core/__init__.py,sha256=m-VP0RW0tOTm9N5NI3kFNcpM7WtVgs0RK9pK3ZJRZQQ,14
 ml_tools/_core/_logger.py,sha256=xzhn_FouMDRVNwXGBGlPC9Ruq6i5uCrmNaS5jesguMU,4972
 ml_tools/_core/_schema_load_ops.py,sha256=KLs9vBzANz5ESe2wlP-C41N4VlgGil-ywcfvWKSOGss,1551
 ml_tools/_core/_script_info.py,sha256=LtFGt10gEvCnhIRMKJPi2yXkiGLcdr7lE-oIP2XGHzQ,234
-ml_tools/data_exploration/__init__.py,sha256=
-ml_tools/data_exploration/_analysis.py,sha256=
+ml_tools/data_exploration/__init__.py,sha256=efUBsruHL56B429tUadl3PdG73zAF639Y430uMQRfko,1917
+ml_tools/data_exploration/_analysis.py,sha256=PJNrEBz5ZZXHoUlQ6fh9Y86nzPQrLpVPv2Ye4NfOxgs,14181
 ml_tools/data_exploration/_cleaning.py,sha256=pAZOXgGK35j7O8q6cnyTwYK1GLNnD04A8p2fSyMB1mg,20906
-ml_tools/data_exploration/_features.py,sha256=
+ml_tools/data_exploration/_features.py,sha256=Z1noJfDxBzFRfusFp6NlpLF2NItuZuzFHq4ssWFqny4,26273
 ml_tools/data_exploration/_plotting.py,sha256=zH1dPcIoAlOuww23xIoBCsQOAshPPv9OyGposOA2RvI,19883
 ml_tools/data_exploration/_schema_ops.py,sha256=Fd6fBGGv4OpxmJ1HG9pith6QL90z0tzssCvzkQxlEEQ,11083
 ml_tools/ensemble_evaluation/__init__.py,sha256=t4Gr8EGEk8RLatyc92-S0BzbQvdvodzoF-qDAH2qjVg,546
@@ -118,7 +119,7 @@ ml_tools/ensemble_learning/_ensemble_learning.py,sha256=MHDZBR20_nStlSSeThFI3bSu
 ml_tools/excel_handler/__init__.py,sha256=AaWM3n_dqBhJLTs3OEA57ex5YykKXNOwVCyHlVsdnqI,530
 ml_tools/excel_handler/_excel_handler.py,sha256=TODudmeQgDSdxUKzLfAzizs--VL-g8WxDOfQ4sgxxLs,13965
 ml_tools/keys/__init__.py,sha256=-0c2pmrhyfROc-oQpEjJGLBMhSagA3CyFijQaaqZRqU,399
-ml_tools/keys/_keys.py,sha256=
+ml_tools/keys/_keys.py,sha256=56hlyPl2VUMsq7cFFLBypWHr-JU6ehWGwZG38l6IjI0,9389
 ml_tools/math_utilities/__init__.py,sha256=K7Obkkc4rPKj4EbRZf1BsXHfiCg7FXYv_aN9Yc2Z_Vg,400
 ml_tools/math_utilities/_math_utilities.py,sha256=BYHIVcM9tuKIhVrkgLLiM5QalJ39zx7dXYy_M9aGgiM,9012
 ml_tools/optimization_tools/__init__.py,sha256=KD8JXpfGuPndO4AHnjJGu6uV1GRwhOfboD0KZV45kzw,658
@@ -129,6 +130,10 @@ ml_tools/path_manager/_dragonmanager.py,sha256=q9wHTKPmdzywEz6N14ipUoeR3MmW0bzB4
 ml_tools/path_manager/_path_tools.py,sha256=LcZE31QlkzZWUR8g1MW_N_mPY2DpKBJLA45VJz7ZYsw,11905
 ml_tools/plot_fonts/__init__.py,sha256=KIxXRCjQ3SliEoLhEcqs7zDVZbVTn38bmSdL-yR1Q2w,187
 ml_tools/plot_fonts/_plot_fonts.py,sha256=mfjXNT9P59ymHoTI85Q8CcvfxfK5BIFBWtTZH-hNIC4,2209
+ml_tools/resampling/__init__.py,sha256=WB1YlNQgOIdSSQn-9eCIaiB0AHLSCkziFufqa-1QBG0,278
+ml_tools/resampling/_base_resampler.py,sha256=8IqkEJ7uiAiC9bqbKfsC-5vIvrN3EwH7lLVDlRKQzM8,1617
+ml_tools/resampling/_multi_resampling.py,sha256=m_iVvXPAu3p_EoBt2VZpgjhPLY1LmKa8fGtQo5E0pWk,7199
+ml_tools/resampling/_single_resampling.py,sha256=zKL4Br7Lm4Jq90X-ewQ6AKTsP923bq9RIMnTxIxtXBc,3896
 ml_tools/schema/__init__.py,sha256=K6uiZ9f0GCQ7etw1yl2-dQVLhU7RkL3KHesO3HNX6v4,334
 ml_tools/schema/_feature_schema.py,sha256=MuPf6Nf7tDhUTGyX7tcFHZh-lLSNsJkLmlf9IxdF4O4,9660
 ml_tools/schema/_gui_schema.py,sha256=IVwN4THAdFrvh2TpV4SFd_zlzMX3eioF-w-qcSVTndE,7245
@@ -138,7 +143,7 @@ ml_tools/utilities/__init__.py,sha256=h4lE3SQstg-opcQj6QSKhu-HkqSbmHExsWoM9vC5D9
 ml_tools/utilities/_translate.py,sha256=U8hRPa3PmTpIf9n9yR3gBGmp_hkcsjQLwjAHSHc0WHs,10325
 ml_tools/utilities/_utility_save_load.py,sha256=EFvFaTaHahDQWdJWZr-j7cHqRbG_Xrpc96228JhV-bs,16773
 ml_tools/utilities/_utility_tools.py,sha256=bN0J9d1S0W5wNzNntBWqDsJcEAK7-1OgQg3X2fwXns0,6918
-dragon_ml_toolbox-20.
-dragon_ml_toolbox-20.
-dragon_ml_toolbox-20.
-dragon_ml_toolbox-20.
+dragon_ml_toolbox-20.9.0.dist-info/METADATA,sha256=ehKhp6BpCkHcZnWpcoZU53rn4T0yI0Dboq3eH2vx8LU,7888
+dragon_ml_toolbox-20.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-20.9.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-20.9.0.dist-info/RECORD,,
```
ml_tools/ML_configuration/_metrics.py
CHANGED
```diff
@@ -98,10 +98,11 @@ class _BaseMultiLabelFormat:
                  cmap: str = "BuGn",
                  ROC_PR_line: str='darkorange',
                  calibration_bins: Union[int, Literal['auto']]='auto',
-                 font_size: int =
-                 xtick_size: int=
-                 ytick_size: int=
-                 legend_size: int=
+                 font_size: int = 26,
+                 xtick_size: int=22,
+                 ytick_size: int=22,
+                 legend_size: int=26,
+                 cm_font_size: int=26) -> None:
         """
         Initializes the formatting configuration for multi-label classification metrics.
 
@@ -127,6 +128,8 @@ class _BaseMultiLabelFormat:
 
             legend_size (int): Font size for plot legends.
 
+            cm_font_size (int): Font size for the confusion matrix.
+
         <br>
 
         ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
@@ -142,6 +145,7 @@ class _BaseMultiLabelFormat:
         self.xtick_size = xtick_size
         self.ytick_size = ytick_size
         self.legend_size = legend_size
+        self.cm_font_size = cm_font_size
 
     def __repr__(self) -> str:
         parts = [
@@ -151,7 +155,8 @@ class _BaseMultiLabelFormat:
             f"font_size={self.font_size}",
             f"xtick_size={self.xtick_size}",
             f"ytick_size={self.ytick_size}",
-            f"legend_size={self.legend_size}"
+            f"legend_size={self.legend_size}",
+            f"cm_font_size={self.cm_font_size}"
         ]
         return f"{self.__class__.__name__}({', '.join(parts)})"
 
@@ -520,10 +525,11 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
                  cmap: str = "BuGn",
                  ROC_PR_line: str='darkorange',
                  calibration_bins: Union[int, Literal['auto']]='auto',
-                 font_size: int =
-                 xtick_size: int=
-                 ytick_size: int=
-                 legend_size: int=
+                 font_size: int = 26,
+                 xtick_size: int=22,
+                 ytick_size: int=22,
+                 legend_size: int=26,
+                 cm_font_size: int=26
                  ) -> None:
         super().__init__(cmap=cmap,
                          ROC_PR_line=ROC_PR_line,
@@ -531,7 +537,8 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
                          font_size=font_size,
                          xtick_size=xtick_size,
                          ytick_size=ytick_size,
-                         legend_size=legend_size
+                         legend_size=legend_size,
+                         cm_font_size=cm_font_size)
 
 
 # Segmentation
```
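The new `cm_font_size` knob threads through both the base formatter and the public config class. A minimal construction sketch, assuming the class is re-exported from `ml_tools.ML_configuration` (the diff only shows the private module `_metrics.py`):

```python
# Hypothetical import path; the diff only shows ml_tools/ML_configuration/_metrics.py.
from ml_tools.ML_configuration import FormatMultiLabelBinaryClassificationMetrics

fmt = FormatMultiLabelBinaryClassificationMetrics(
    cmap="BuGn",
    ROC_PR_line="darkorange",
    font_size=26,
    cm_font_size=30,  # new in 20.9.0: font size for confusion-matrix/heatmap text
)
print(fmt)  # __repr__ now reports cm_font_size as well
```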
ml_tools/ML_datasetmaster/_base_datasetmaster.py
CHANGED
```diff
@@ -133,7 +133,7 @@ class _BaseDatasetMaker(ABC):
 
         # Get continuous feature indices *from the schema*
         if schema.continuous_feature_names:
-            if verbose >=
+            if verbose >= 3:
                 _LOGGER.info("Getting continuous feature indices from schema.")
             try:
                 # Convert columns to a standard list for .index()
@@ -189,7 +189,7 @@ class _BaseDatasetMaker(ABC):
         # ------------------------------------------------------------------
 
         if self.target_scaler is None:
-            if verbose >=
+            if verbose >= 3:
                 _LOGGER.info("Fitting a new DragonScaler on training targets.")
             # Convert to float tensor for calculation
             y_train_tensor = torch.tensor(y_train_arr, dtype=torch.float32)
@@ -202,6 +202,9 @@ class _BaseDatasetMaker(ABC):
             y_val_tensor = self.target_scaler.transform(torch.tensor(y_val_arr, dtype=torch.float32))
             y_test_tensor = self.target_scaler.transform(torch.tensor(y_test_arr, dtype=torch.float32))
             return y_train_tensor.numpy(), y_val_tensor.numpy(), y_test_tensor.numpy()
+
+        if verbose >= 2:
+            _LOGGER.info("Target scaling transformation complete.")
 
         return y_train_arr, y_val_arr, y_test_arr
 
@@ -214,6 +217,9 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def train_dataset(self) -> Dataset:
+        """
+        Returns the training dataset.
+        """
         if self._train_ds is None:
             _LOGGER.error("Train Dataset not yet created.")
             raise RuntimeError()
@@ -221,6 +227,9 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def validation_dataset(self) -> Dataset:
+        """
+        Returns the validation dataset.
+        """
         if self._val_ds is None:
             _LOGGER.error("Validation Dataset not yet created.")
             raise RuntimeError()
@@ -228,6 +237,9 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def test_dataset(self) -> Dataset:
+        """
+        Returns the test dataset.
+        """
         if self._test_ds is None:
             _LOGGER.error("Test Dataset not yet created.")
             raise RuntimeError()
@@ -235,30 +247,50 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def feature_names(self) -> list[str]:
+        """
+        Returns a list with the feature names.
+        """
         return self._feature_names
 
     @property
     def target_names(self) -> list[str]:
+        """
+        Returns a list with the target names.
+        """
         return self._target_names
 
     @property
     def number_of_features(self) -> int:
+        """
+        Returns the number of features.
+        """
         return len(self._feature_names)
 
     @property
     def number_of_targets(self) -> int:
+        """
+        Returns the number of targets.
+        """
         return len(self._target_names)
 
     @property
     def id(self) -> Optional[str]:
+        """
+        Returns the dataset ID if set, otherwise None.
+        """
         return self._id
 
     @id.setter
     def id(self, dataset_id: str):
-        if not isinstance(dataset_id, str):
+        if not isinstance(dataset_id, str):
+            _LOGGER.error("Dataset ID must be a string.")
+            raise ValueError()
         self._id = dataset_id
 
     def dataframes_info(self) -> None:
+        """
+        Prints the shapes of the dataframes after the split.
+        """
         print("--- DataFrame Shapes After Split ---")
         print(f"  X_train shape: {self._X_train_shape}, y_train shape: {self._y_train_shape}")
         print(f"  X_val shape: {self._X_val_shape}, y_val shape: {self._y_val_shape}")
@@ -266,12 +298,26 @@ class _BaseDatasetMaker(ABC):
         print("------------------------------------")
 
     def save_feature_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """
+        Saves the feature names to a text file.
+
+        Args:
+            directory (str | Path): Directory to save the feature names.
+            verbose (bool): Whether to print log messages.
+        """
         save_list_strings(list_strings=self._feature_names,
                           directory=directory,
                           filename=DatasetKeys.FEATURE_NAMES,
                           verbose=verbose)
 
     def save_target_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """
+        Saves the target names to a text file.
+
+        Args:
+            directory (str | Path): Directory to save the target names.
+            verbose (bool): Whether to print log messages.
+        """
         save_list_strings(list_strings=self._target_names,
                           directory=directory,
                           filename=DatasetKeys.TARGET_NAMES,
@@ -281,6 +327,10 @@ class _BaseDatasetMaker(ABC):
         """
         Saves both feature and target scalers (if they exist) to a single .pth file
         using a dictionary structure.
+
+        Args:
+            directory (str | Path): Directory to save the scaler.
+            verbose (bool): Whether to print log messages.
         """
         if self.feature_scaler is None and self.target_scaler is None:
             _LOGGER.warning("No scalers (feature or target) were fitted. Nothing to save.")
```
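The `id` setter now validates its input with the library's log-then-raise pattern. A standalone sketch of that pattern, using an illustrative class and logger rather than the package's own:

```python
import logging
from typing import Optional

_LOGGER = logging.getLogger("demo")

class DatasetIdHolder:
    """Illustrative stand-in for _BaseDatasetMaker's id property."""
    def __init__(self) -> None:
        self._id: Optional[str] = None

    @property
    def id(self) -> Optional[str]:
        """Returns the dataset ID if set, otherwise None."""
        return self._id

    @id.setter
    def id(self, dataset_id: str) -> None:
        # Log a human-readable message, then raise with an empty exception,
        # mirroring the style added in the diff above.
        if not isinstance(dataset_id, str):
            _LOGGER.error("Dataset ID must be a string.")
            raise ValueError()
        self._id = dataset_id
```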
ml_tools/ML_evaluation/_classification.py
CHANGED
```diff
@@ -28,6 +28,8 @@ from ..path_manager import make_fullpath, sanitize_filename
 from .._core import get_logger
 from ..keys._keys import _EvaluationConfig
 
+from ._helpers import check_and_abbreviate_name
+
 
 _LOGGER = get_logger("Classification Metrics")
 
@@ -85,7 +87,8 @@ def classification_metrics(save_dir: Union[str, Path],
     try:
         sorted_items = sorted(class_map.items(), key=lambda item: item[1])
         map_labels = [item[1] for item in sorted_items]
-
+        # Abbreviate display labels if needed
+        map_display_labels = [check_and_abbreviate_name(item[0]) for item in sorted_items]
     except Exception as e:
         _LOGGER.warning(f"Could not parse 'class_map': {e}")
         map_labels = None
@@ -397,6 +400,10 @@ def classification_metrics(save_dir: Union[str, Path],
     # --- Step 1: Get binned data directly ---
     # calculates reliability diagram data without needing a temporary plot
     prob_true, prob_pred = calibration_curve(y_true_binary, y_score, n_bins=dynamic_bins)
+
+    # Anchor the plot to (0,0) and (1,1) to ensure the line spans the full diagonal
+    prob_true = np.concatenate(([0.0], prob_true, [1.0]))
+    prob_pred = np.concatenate(([0.0], prob_pred, [1.0]))
 
     # --- Step 2: Plot ---
     ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
@@ -467,6 +474,9 @@ def multi_label_classification_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
 
+    # --- Pre-process target names for abbreviation ---
+    target_names = [check_and_abbreviate_name(name) for name in target_names]
+
     # --- Parse Config or use defaults ---
     if config is None:
         # Create a default config if one wasn't provided
@@ -481,6 +491,10 @@ def multi_label_classification_metrics(
     ytick_size = format_config.ytick_size
     legend_size = format_config.legend_size
     base_font_size = format_config.font_size
+
+    # config font size for heatmap
+    cm_font_size = format_config.cm_font_size
+    cm_tick_size = cm_font_size - 4
 
     # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
@@ -488,7 +502,7 @@ def multi_label_classification_metrics(
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics:\n"
+        f"Overall Multi-Label Metrics:\n"
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -499,9 +513,65 @@ def multi_label_classification_metrics(
     overall_report_path = save_dir_path / "classification_report.txt"
     overall_report_path.write_text(overall_report)
 
+    # --- Save Classification Report Heatmap (Multi-label) ---
+    try:
+        # Generate full report as dict
+        full_report_dict = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
+        report_df = pd.DataFrame(full_report_dict)
+
+        # Cleanup
+        # Remove 'accuracy' column if it exists
+        report_df = report_df.drop(columns=['accuracy'], errors='ignore')
+
+        # Remove 'support' row explicitly
+        if 'support' in report_df.index:
+            report_df = report_df.drop(index='support')
+
+        # Transpose: Rows = Classes/Averages, Cols = Metrics
+        plot_df = report_df.T
+
+        # Dynamic Height
+        fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
+        fig_width = 8.0
+
+        fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
+
+        # Plot
+        sns.heatmap(plot_df,
+                    annot=True,
+                    cmap=format_config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0,
+                    cbar_kws={'shrink': 0.9})
+
+        ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+
+        # manually increase the font size of the elements
+        for text in ax_heat.texts:
+            text.set_fontsize(cm_tick_size)
+
+        cbar = ax_heat.collections[0].colorbar
+        cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore
+
+        ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
+        ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0)
+
+        plt.tight_layout()
+        heatmap_path = save_dir_path / "classification_report_heatmap.svg"
+        plt.savefig(heatmap_path)
+        _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
+        plt.close(fig_heat)
+
+    except Exception as e:
+        _LOGGER.error(f"Could not generate multi-label classification report heatmap: {e}")
+
     # --- Per-Label Metrics and Plots ---
     for i, name in enumerate(target_names):
-
+        # strip whitespace from name
+        name = name.strip()
+
+        # print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
         pred_i = y_pred[:, i]  # Use passed-in y_pred
         prob_i = y_prob[:, i]  # Use passed-in y_prob
@@ -537,7 +607,7 @@ def multi_label_classification_metrics(
         ax_cm.tick_params(axis='y', labelsize=ytick_size)
 
         # Set titles and labels with padding
-        ax_cm.set_title(f"Confusion Matrix
+        ax_cm.set_title(f"Confusion Matrix - {name}", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -594,7 +664,7 @@ def multi_label_classification_metrics(
         ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)  # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
 
-        ax_roc.set_title(f'ROC Curve
+        ax_roc.set_title(f'ROC Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -616,7 +686,7 @@ def multi_label_classification_metrics(
         ap_score = average_precision_score(true_i, prob_i)
         fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
         ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line)  # Use config color
-        ax_pr.set_title(f'
+        ax_pr.set_title(f'PR Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -654,15 +724,20 @@ def multi_label_classification_metrics(
         # Calculate calibration curve for this specific label
         prob_true, prob_pred = calibration_curve(true_i, prob_i, n_bins=dynamic_bins)
 
+        # Anchor the plot to (0,0) and (1,1)
+        prob_true = np.concatenate(([0.0], prob_true, [1.0]))
+        prob_pred = np.concatenate(([0.0], prob_pred, [1.0]))
+
+        # Plot the calibration curve
         ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
         ax_cal.plot(prob_pred,
                     prob_true,
                     marker='o',
                     linewidth=2,
-                    label=f"Calibration
+                    label=f"Model Calibration",
                     color=format_config.ROC_PR_line)
 
-        ax_cal.set_title(f'
+        ax_cal.set_title(f'Calibration - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
```
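Both calibration plots now anchor the reliability curve at (0,0) and (1,1). A self-contained sketch of the same trick on synthetic data (all names here are illustrative, not from the package):

```python
import numpy as np
from sklearn.calibration import calibration_curve

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=500)
y_prob = np.clip(0.6 * y_true + rng.normal(0.2, 0.2, size=500), 0.0, 1.0)

prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=10)

# calibration_curve only returns per-bin means, so the raw curve starts and ends
# mid-plot; prepending (0,0) and appending (1,1) makes it span the full diagonal.
prob_true = np.concatenate(([0.0], prob_true, [1.0]))
prob_pred = np.concatenate(([0.0], prob_pred, [1.0]))
```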
ml_tools/ML_evaluation/_helpers.py
ADDED
```diff
@@ -0,0 +1,41 @@
+from ..keys._keys import _EvaluationConfig
+from ..path_manager import sanitize_filename
+from .._core import get_logger
+
+
+_LOGGER = get_logger("Metrics Helper")
+
+
+def check_and_abbreviate_name(name: str) -> str:
+    """
+    Checks if a name exceeds the NAME_LIMIT. If it does, creates an abbreviation
+    (initials of words) or truncates it if the abbreviation is empty.
+
+    Args:
+        name (str): The original label or target name.
+
+    Returns:
+        str: The potentially abbreviated name.
+    """
+    limit = _EvaluationConfig.NAME_LIMIT
+
+    # Strip whitespace
+    name = name.strip()
+
+    if len(name) <= limit:
+        return name
+
+    # Attempt abbreviation: First letter of each word (split by space or underscore)
+    parts = [w for w in name.replace("_", " ").split() if w]
+    abbr = "".join(p[0].upper() for p in parts)
+
+    # Keep only alphanumeric characters
+    abbr = "".join(ch for ch in abbr if ch.isalnum())
+
+    # Fallback if abbreviation failed or is empty
+    if not abbr:
+        sanitized = sanitize_filename(name)
+        abbr = sanitized[:limit]
+
+    _LOGGER.warning(f"Label '{name}' is too long. Abbreviating to '{abbr}'.")
+    return abbr
```
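A quick sketch of the helper's behavior with `NAME_LIMIT = 15` (the direct import of the private module is shown only for illustration):

```python
from ml_tools.ML_evaluation._helpers import check_and_abbreviate_name

check_and_abbreviate_name("short_label")                   # -> "short_label" (within limit)
check_and_abbreviate_name("very_long_target_column_name")  # -> "VLTCN" (word initials), with a warning logged
```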
ml_tools/ML_evaluation/_regression.py
CHANGED
```diff
@@ -19,6 +19,8 @@ from ..path_manager import make_fullpath, sanitize_filename
 from .._core import get_logger
 from ..keys._keys import _EvaluationConfig
 
+from ._helpers import check_and_abbreviate_name
+
 
 _LOGGER = get_logger("Regression Metrics")
 
@@ -180,6 +182,9 @@ def multi_target_regression_metrics(
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
         raise ValueError()
+
+    # --- Pre-process target names for abbreviation ---
+    target_names = [check_and_abbreviate_name(name) for name in target_names]
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
```
ml_tools/data_exploration/__init__.py
CHANGED
```diff
@@ -2,6 +2,7 @@ from ._analysis import (
     summarize_dataframe,
     show_null_columns,
     match_and_filter_columns_by_regex,
+    check_class_balance,
 )
 
 from ._cleaning import (
@@ -28,6 +29,7 @@ from ._features import (
     split_continuous_binary,
     split_continuous_categorical_targets,
     encode_categorical_features,
+    encode_classification_target,
     reconstruct_one_hot,
     reconstruct_binary,
     reconstruct_multibinary,
@@ -44,7 +46,6 @@ from .._core import _imprimir_disponibles
 
 __all__ = [
     "summarize_dataframe",
-    "show_null_columns",
     "drop_constant_columns",
     "drop_rows_with_missing_data",
     "drop_columns_with_missing_data",
@@ -61,10 +62,13 @@ __all__ = [
     "plot_categorical_vs_target",
     "plot_correlation_heatmap",
    "encode_categorical_features",
+    "encode_classification_target",
     "finalize_feature_schema",
     "apply_feature_schema",
     "reconstruct_from_schema",
     "match_and_filter_columns_by_regex",
+    "show_null_columns",
+    "check_class_balance",
     "standardize_percentages",
     "reconstruct_one_hot",
     "reconstruct_binary",
```
ml_tools/data_exploration/_analysis.py
CHANGED
```diff
@@ -16,6 +16,7 @@ __all__ = [
     "summarize_dataframe",
     "show_null_columns",
     "match_and_filter_columns_by_regex",
+    "check_class_balance",
 ]
 
 
@@ -212,3 +213,151 @@ def match_and_filter_columns_by_regex(
 
     return filtered_df, matched_columns
 
+
+def check_class_balance(
+    df: pd.DataFrame,
+    target: Union[str, list[str]],
+    plot_to_dir: Optional[Union[str, Path]] = None,
+    plot_filename: str = "Class_Balance"
+) -> pd.DataFrame:
+    """
+    Analyzes the class balance for classification targets.
+
+    Handles two cases:
+    1. Single Column (Binary/Multi-class): Calculates frequency of each unique value.
+    2. List of Columns (Multi-label Binary): Calculates the frequency of positive values (1) per column.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        target (str | list[str]): The target column name (for single/multi-class classification)
+            or list of column names (for multi-label-binary classification).
+        plot_to_dir (str | Path | None): Directory to save the balance plot.
+        plot_filename (str): Filename for the plot (without extension).
+
+    Returns:
+        pd.DataFrame: Summary table of counts and percentages.
+    """
+    # Early fail for empty DataFrame and handle list of targets with only one item
+    if df.empty:
+        _LOGGER.error("Input DataFrame is empty.")
+        raise ValueError()
+
+    if isinstance(target, list):
+        if len(target) == 0:
+            _LOGGER.error("Target list is empty.")
+            raise ValueError()
+        elif len(target) == 1:
+            target = target[0]  # Simplify to single column case
+
+    # Case 1: Single Target (Binary or Multi-class)
+    if isinstance(target, str):
+        if target not in df.columns:
+            _LOGGER.error(f"Target column '{target}' not found in DataFrame.")
+            raise ValueError()
+
+        # Calculate stats
+        counts = df[target].value_counts(dropna=False).sort_index()
+        percents = df[target].value_counts(normalize=True, dropna=False).sort_index() * 100
+
+        summary = pd.DataFrame({
+            'Count': counts,
+            'Percentage': percents.round(2)
+        })
+        summary.index.name = "Class"
+
+        # Plotting
+        if plot_to_dir:
+            try:
+                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")
+
+                plt.figure(figsize=(10, 6))
+                # Convert index to str to handle numeric classes cleanly on x-axis
+                x_labels = summary.index.astype(str)
+                bars = plt.bar(x_labels, summary['Count'], color='lightgreen', edgecolor='black', alpha=0.7)
+
+                plt.title(f"Class Balance: {target}")
+                plt.xlabel(target)
+                plt.ylabel("Count")
+                plt.grid(axis='y', linestyle='--', alpha=0.5)
+
+                # Add percentage labels on top of bars
+                for bar, pct in zip(bars, summary['Percentage']):
+                    height = bar.get_height()
+                    plt.text(bar.get_x() + bar.get_width()/2, height,
+                             f'{pct:.1f}%', ha='center', va='bottom', fontsize=10)
+
+                plt.tight_layout()
+                full_filename = sanitize_filename(plot_filename) + ".svg"
+                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
+                plt.close()
+                _LOGGER.info(f"Saved class balance plot: '{full_filename}'")
+            except Exception as e:
+                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
+                plt.close()
+
+        return summary
+
+    # Case 2: Multi-label (List of binary columns)
+    elif isinstance(target, list):
+        missing_cols = [t for t in target if t not in df.columns]
+        if missing_cols:
+            _LOGGER.error(f"Target columns not found: {missing_cols}")
+            raise ValueError()
+
+        stats = []
+        for col in target:
+            # Assume 0/1 or False/True. Sum gives the count of positives.
+            # We enforce numeric to be safe
+            try:
+                numeric_series = pd.to_numeric(df[col], errors='coerce').fillna(0)
+                pos_count = numeric_series.sum()
+                total_count = len(df)
+                pct = (pos_count / total_count) * 100
+            except Exception:
+                _LOGGER.warning(f"Column '{col}' could not be processed as numeric. Assuming 0 positives.")
+                pos_count = 0
+                pct = 0.0
+
+            stats.append({
+                'Label': col,
+                'Positive_Count': int(pos_count),
+                'Positive_Percentage': round(pct, 2)
+            })
+
+        summary = pd.DataFrame(stats).set_index("Label").sort_values("Positive_Percentage", ascending=True)
+
+        # Plotting
+        if plot_to_dir:
+            try:
+                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")
+
+                # Dynamic height for many labels
+                height = max(6, len(target) * 0.4)
+                plt.figure(figsize=(10, height))
+
+                bars = plt.barh(summary.index, summary['Positive_Percentage'], color='lightgreen', edgecolor='black', alpha=0.7)
+
+                plt.title(f"Multi-label Binary Class Balance")
+                plt.xlabel("Positive Class Percentage (%)")
+                plt.xlim(0, 100)
+                plt.grid(axis='x', linestyle='--', alpha=0.5)
+
+                # Add count labels at the end of bars
+                for bar, count in zip(bars, summary['Positive_Count']):
+                    width = bar.get_width()
+                    plt.text(width + 1, bar.get_y() + bar.get_height()/2, f'{width:.1f}%', ha='left', va='center', fontsize=9)
+
+                plt.tight_layout()
+                full_filename = sanitize_filename(plot_filename) + ".svg"
+                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
+                plt.close()
+                _LOGGER.info(f"Saved multi-label balance plot: '{full_filename}'")
+            except Exception as e:
+                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
+                plt.close()
+
+        return summary.sort_values("Positive_Percentage", ascending=False)
+
+    else:
+        _LOGGER.error("Target must be a string or a list of strings.")
+        raise TypeError()
```
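A brief usage sketch for both modes, on an invented DataFrame (`check_class_balance` is re-exported from `ml_tools.data_exploration` per the `__init__.py` change above):

```python
import pandas as pd
from ml_tools.data_exploration import check_class_balance

df = pd.DataFrame({
    "label":      ["cat", "dog", "cat", "cat", "bird"],
    "is_toxic":   [0, 1, 0, 0, 1],
    "is_flagged": [1, 1, 0, 0, 0],
})

# Single multi-class target: per-class counts and percentages
summary = check_class_balance(df, target="label")

# Multi-label binary targets: positive counts per column, plus an SVG bar plot
summary_ml = check_class_balance(df, target=["is_toxic", "is_flagged"], plot_to_dir="reports")
```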
ml_tools/data_exploration/_features.py
CHANGED
```diff
@@ -3,7 +3,10 @@ from pandas.api.types import is_numeric_dtype, is_object_dtype
 import numpy as np
 from typing import Any, Optional, Union
 import re
+import json
+from pathlib import Path
 
+from ..path_manager import make_fullpath
 from .._core import get_logger
 
 
@@ -15,6 +18,7 @@ __all__ = [
     "split_continuous_binary",
     "split_continuous_categorical_targets",
     "encode_categorical_features",
+    "encode_classification_target",
     "reconstruct_one_hot",
     "reconstruct_binary",
     "reconstruct_multibinary",
@@ -263,6 +267,78 @@ def encode_categorical_features(
     return df_encoded, mappings
 
 
+def encode_classification_target(
+    df: pd.DataFrame,
+    target_col: str,
+    save_dir: Union[str, Path],
+    verbose: int = 2
+) -> tuple[pd.DataFrame, dict[str, int]]:
+    """
+    Encodes a target classification column into integers (0, 1, 2...) and saves the mapping to a JSON file.
+
+    This ensures that the target variable is in the correct numeric format for training
+    and provides a persistent artifact (the JSON file) to map predictions back to labels later.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame.
+        target_col (str): Name of the target column to encode.
+        save_dir (str | Path): Directory where the class map JSON will be saved.
+        verbose (int): Verbosity level for logging.
+
+    Returns:
+        Tuple (Dataframe, Dict):
+        - A new DataFrame with the target column encoded as integers.
+        - The dictionary mapping original labels (str) to integers (int).
+    """
+    if target_col not in df.columns:
+        _LOGGER.error(f"Target column '{target_col}' not found in DataFrame.")
+        raise ValueError()
+
+    # Validation: Check for missing values in target
+    if df[target_col].isnull().any():
+        n_missing = df[target_col].isnull().sum()
+        _LOGGER.error(f"Target column '{target_col}' contains {n_missing} missing values. Please handle them before encoding.")
+        raise ValueError()
+
+    # Ensure directory exists
+    save_path = make_fullpath(save_dir, make=True, enforce="directory")
+    file_path = save_path / "class_map.json"
+
+    # Get unique values and sort them to ensure deterministic encoding (0, 1, 2...)
+    # Convert to string to ensure the keys in JSON are strings
+    unique_labels = sorted(df[target_col].astype(str).unique())
+
+    # Create mapping: { Label -> Integer }
+    class_map = {label: idx for idx, label in enumerate(unique_labels)}
+
+    # Apply mapping
+    # cast column to string to match the keys in class_map
+    df_encoded = df.copy()
+    df_encoded[target_col] = df_encoded[target_col].astype(str).map(class_map)
+
+    # Save to JSON
+    try:
+        with open(file_path, 'w', encoding='utf-8') as f:
+            json.dump(class_map, f, indent=4)
+
+        if verbose >= 2:
+            _LOGGER.info(f"Class mapping saved to: '{file_path}'")
+
+        if verbose >= 3:
+            _LOGGER.info(f"Target '{target_col}' encoded with {len(class_map)} classes.")
+            # Print a preview
+            if len(class_map) <= 10:
+                print(f"  Mapping: {class_map}")
+            else:
+                print(f"  Mapping (first 5): {dict(list(class_map.items())[:5])} ...")
+
+    except Exception as e:
+        _LOGGER.error(f"Failed to save class map JSON. Error: {e}")
+        raise IOError()
+
+    return df_encoded, class_map
+
+
 def reconstruct_one_hot(
     df: pd.DataFrame,
     features_to_reconstruct: list[Union[str, tuple[str, Optional[str]]]],
```
ml_tools/keys/_keys.py
CHANGED
```diff
@@ -306,6 +306,7 @@ class _EvaluationConfig:
     LOSS_PLOT_LEGEND_SIZE = 24
     # CM settings
     CM_SIZE = (9, 8) # used for multi label binary classification confusion matrix
+    NAME_LIMIT = 15 # max number of characters for feature/label names in plots
 
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
```
ml_tools/resampling/__init__.py
ADDED
```diff
@@ -0,0 +1,19 @@
+from ._single_resampling import (
+    DragonResampler,
+)
+
+from ._multi_resampling import (
+    DragonMultiResampler,
+)
+
+from .._core import _imprimir_disponibles
+
+
+__all__ = [
+    "DragonResampler",
+    "DragonMultiResampler",
+]
+
+
+def info():
+    _imprimir_disponibles(__all__)
```
ml_tools/resampling/_base_resampler.py
ADDED
```diff
@@ -0,0 +1,49 @@
+import polars as pl
+import pandas as pd
+from typing import Union
+from abc import ABC, abstractmethod
+
+
+__all__ = ["_DragonBaseResampler"]
+
+
+class _DragonBaseResampler(ABC):
+    """
+    Base class for Dragon resamplers handling common I/O and state.
+    """
+    def __init__(self,
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        self.return_pandas = return_pandas
+        self.seed = seed
+
+    def _convert_to_polars(self, df: Union[pd.DataFrame, pl.DataFrame]) -> pl.DataFrame:
+        """Standardizes input to Polars DataFrame."""
+        if isinstance(df, pd.DataFrame):
+            return pl.from_pandas(df)
+        return df
+
+    def _convert_to_pandas(self, df: pl.DataFrame) -> pd.DataFrame:
+        """Converts Polars DataFrame back to Pandas."""
+        return df.to_pandas(use_pyarrow_extension_array=False)
+
+    def _process_return(self, df: pl.DataFrame, shuffle: bool = True) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Finalizes the DataFrame:
+        1. Global Shuffle (optional but recommended for ML).
+        2. Conversion to Pandas (if requested).
+        """
+        if shuffle:
+            # Random shuffle of the final dataset
+            df = df.sample(fraction=1.0, seed=self.seed, with_replacement=False)
+
+        if self.return_pandas:
+            return self._convert_to_pandas(df)
+        return df
+
+    @abstractmethod
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        """
+        Prints a statistical summary of the target distribution.
+        """
+        pass
```
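The base class centralizes pandas/polars conversion and the final shuffle. A minimal, hypothetical subclass sketch showing how the helpers compose; only the `DragonMultiResampler` and `DragonResampler` below actually ship:

```python
import pandas as pd
import polars as pl
from typing import Union
# Private module; imported here only for illustration.
from ml_tools.resampling._base_resampler import _DragonBaseResampler

class IdentityResampler(_DragonBaseResampler):
    """Hypothetical no-op resampler demonstrating the shared I/O helpers."""

    def resample(self, df: Union[pd.DataFrame, pl.DataFrame]) -> Union[pd.DataFrame, pl.DataFrame]:
        df_pl = self._convert_to_polars(df)  # accepts pandas or polars input
        return self._process_return(df_pl)   # shuffles, then honors return_pandas

    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
        print(self._convert_to_polars(df).head(top_n))
```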
ml_tools/resampling/_multi_resampling.py
ADDED
```diff
@@ -0,0 +1,184 @@
+import polars as pl
+import pandas as pd
+import numpy as np
+from typing import Union, Optional
+
+from .._core import get_logger
+
+from ._base_resampler import _DragonBaseResampler
+
+
+_LOGGER = get_logger("DragonMultiResampler")
+
+
+__all__ = [
+    "DragonMultiResampler",
+]
+
+
+class DragonMultiResampler(_DragonBaseResampler):
+    """
+    A robust resampler for multi-label binary classification tasks using Polars.
+
+    It provides methods to downsample "all-negative" rows and balance the dataset
+    based on unique label combinations (Powerset).
+    """
+    def __init__(self,
+                 target_columns: list[str],
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        """
+        Args:
+            target_columns (List[str]): The list of binary target column names.
+            return_pandas (bool): Whether to return results as pandas DataFrame.
+            seed (int): Random seed for reproducibility.
+        """
+        super().__init__(return_pandas=return_pandas, seed=seed)
+        self.targets = target_columns
+
+    def downsample_all_negatives(self,
+                                 df: Union[pd.DataFrame, pl.DataFrame],
+                                 negative_ratio: float = 1.0,
+                                 verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Downsamples rows where ALL target columns are 0 ("background" class).
+
+        Args:
+            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
+            negative_ratio (float): Ratio of negatives to positives to retain.
+            verbose (int): Verbosity level for logging.
+
+        Returns:
+            Dataframe: Resampled DataFrame.
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Identify "All Negative" vs "Has Signal"
+        fold_expr = pl.sum_horizontal(pl.col(self.targets)).cast(pl.UInt32)
+
+        df_pos = df_pl.filter(fold_expr > 0)
+        df_neg = df_pl.filter(fold_expr == 0)
+
+        n_pos = df_pos.height
+        n_neg_original = df_neg.height
+
+        if n_pos == 0:
+            if verbose >= 1:
+                _LOGGER.warning("No positive cases found in any label. Returning original DataFrame.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 2. Calculate target count for negatives
+        target_n_neg = int(n_pos * negative_ratio)
+
+        # 3. Sample if necessary
+        if n_neg_original > target_n_neg:
+            if verbose >= 2:
+                _LOGGER.info(f"📉 Downsampling 'All-Negative' rows from {n_neg_original} to {target_n_neg}")
+
+            # Here we use standard sampling because we are not grouping
+            df_neg_sampled = df_neg.sample(n=target_n_neg, seed=self.seed, with_replacement=False)
+            df_resampled = pl.concat([df_pos, df_neg_sampled])
+
+            return self._process_return(df_resampled)
+        else:
+            if verbose >= 1:
+                _LOGGER.warning(f"Negative count ({n_neg_original}) is already below target ({target_n_neg}). No downsampling applied.")
+            return self._process_return(df_pl, shuffle=False)
+
+    def balance_powerset(self,
+                         df: Union[pd.DataFrame, pl.DataFrame],
+                         max_samples_per_combination: Optional[int] = None,
+                         quantile_limit: float = 0.90,
+                         verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Groups data by unique label combinations (Powerset) and downsamples
+        majority combinations.
+
+        Args:
+            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
+            max_samples_per_combination (int | None): Fixed cap per combination.
+                If None, uses quantile_limit to determine cap.
+            quantile_limit (float): Quantile to determine cap if max_samples_per_combination is None.
+            verbose (int): Verbosity level for logging.
+
+        Returns:
+            Dataframe: Resampled DataFrame.
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Create a hash/structural representation of the targets for grouping
+        df_lazy = df_pl.lazy().with_columns(
+            pl.concat_list(pl.col(self.targets)).alias("_powerset_key")
+        )
+
+        # 2. Calculate frequencies
+        # We need to collect partially to calculate the quantile cap
+        combo_counts = df_lazy.group_by("_powerset_key").len().collect()
+
+        # Determine the Cap
+        if max_samples_per_combination is None:
+            # Handle potential None from quantile (satisfies linter)
+            q_val = combo_counts["len"].quantile(quantile_limit)
+
+            if q_val is None:
+                if verbose >= 1:
+                    _LOGGER.warning("Data empty or insufficient to calculate quantile. Returning original.")
+                return self._process_return(df_pl, shuffle=False)
+
+            cap_size = int(q_val)
+
+            if verbose >= 3:
+                _LOGGER.info(f"📊 Auto-calculated Powerset Cap: {cap_size} samples (based on {quantile_limit} quantile).")
+        else:
+            cap_size = max_samples_per_combination
+
+        # 3. Apply Stratified Sampling / Capping (Randomized)
+        df_balanced = (
+            df_lazy
+            .filter(
+                pl.int_range(0, pl.len())
+                .shuffle(seed=self.seed)
+                .over("_powerset_key")
+                < cap_size
+            )
+            .drop("_powerset_key")
+            .collect()
+        )
+
+        if verbose >= 2:
+            original_count = df_pl.height
+            new_count = df_balanced.height
+            _LOGGER.info(f"⚖️ Powerset Balancing: Reduced from {original_count} to {new_count} rows.")
+
+        return self._process_return(df_balanced)
+
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        df_pl = self._convert_to_polars(df)
+        total_rows = df_pl.height
+
+        message_1 = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Multi-Targets: {len(self.targets)} columns"
+
+        # A. Individual Label Counts
+        sums = df_pl.select([
+            pl.sum(col).alias(col) for col in self.targets
+        ]).transpose(include_header=True, header_name="Label", column_names=["Count"])
+
+        sums = sums.with_columns(
+            (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+        ).sort("Count", descending=True)
+
+        message_1 += "\n🔹 Individual Label Frequencies:"
+
+        # B. Powerset (Combination) Counts
+        message_2 = f"🔹 Top {top_n} Label Combinations (Powerset):"
+
+        combo_stats = (
+            df_pl.group_by(self.targets)
+            .len(name="Count")
+            .sort("Count", descending=True)
+            .with_columns(
+                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+            )
+        )
+
+        _LOGGER.info(f"{message_1}\n{sums.head(top_n)}\n{message_2}\n{combo_stats.head(top_n)}")
```
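A usage sketch with an invented multi-label frame; both resamplers are exported from `ml_tools.resampling` per the new `__init__.py` above:

```python
import polars as pl
from ml_tools.resampling import DragonMultiResampler

df = pl.DataFrame({
    "feature": [0.1, 0.5, 0.9, 0.2, 0.7, 0.3],
    "tag_a":   [0, 1, 0, 0, 1, 0],
    "tag_b":   [0, 0, 1, 0, 1, 0],
})

resampler = DragonMultiResampler(target_columns=["tag_a", "tag_b"], seed=42)
resampler.describe_balance(df)  # logs per-label frequencies and top powerset combinations

# Keep one all-negative row per positive row, then cap each label combination
df_down = resampler.downsample_all_negatives(df, negative_ratio=1.0)
df_bal = resampler.balance_powerset(df_down, quantile_limit=0.90)
```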
ml_tools/resampling/_single_resampling.py
ADDED
```diff
@@ -0,0 +1,113 @@
+import polars as pl
+import pandas as pd
+import numpy as np
+from typing import Union
+
+from .._core import get_logger
+
+from ._base_resampler import _DragonBaseResampler
+
+
+_LOGGER = get_logger("DragonResampler")
+
+
+__all__ = [
+    "DragonResampler",
+]
+
+
+class DragonResampler(_DragonBaseResampler):
+    """
+    A resampler for Single-Target Classification tasks (Binary or Multiclass).
+
+    It balances classes by downsampling majority classes relative to the size of the minority class.
+    """
+    def __init__(self,
+                 target_column: str,
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        """
+        Args:
+            target_column (str): The name of the single target column.
+            return_pandas (bool): Whether to return results as pandas DataFrame.
+            seed (int): Random seed for reproducibility.
+        """
+        super().__init__(return_pandas=return_pandas, seed=seed)
+        self.target = target_column
+
+    def balance_classes(self,
+                        df: Union[pd.DataFrame, pl.DataFrame],
+                        majority_ratio: float = 1.0,
+                        verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Downsamples all classes to match the minority class size (scaled by a ratio).
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Calculate Class Counts
+        counts = df_pl.group_by(self.target).len().sort("len")
+
+        if counts.height == 0:
+            _LOGGER.error("DataFrame is empty or target column missing.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 2. Identify Statistics
+        min_val = counts["len"].min()
+        max_val = counts["len"].max()
+
+        if min_val is None or max_val is None:
+            _LOGGER.error("Failed to calculate class statistics (unexpected None).")
+            raise ValueError()
+
+        minority_count: int = min_val # type: ignore
+        majority_count: int = max_val # type: ignore
+
+        # Calculate the cap
+        cap_size = int(minority_count * majority_ratio)
+
+        if verbose >= 3:
+            _LOGGER.info(f"📊 Class Distribution:\n{counts}")
+            _LOGGER.info(f"🎯 Strategy: Cap majorities at {cap_size}")
+
+        # Optimization: If data is already balanced enough
+        if majority_count <= cap_size:
+            if verbose >= 2:
+                _LOGGER.info("Data is already within the requested balance ratio.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 3. Apply Downsampling (Randomized)
+        # We generate a random range index per group and filter by it.
+        # This ensures we pick a random subset, not the first N rows.
+        df_balanced = (
+            df_pl.lazy()
+            .filter(
+                pl.int_range(0, pl.len())
+                .shuffle(seed=self.seed)
+                .over(self.target)
+                < cap_size
+            )
+            .collect()
+        )
+
+        if verbose >= 2:
+            reduced_count = df_balanced.height
+            _LOGGER.info(f"⚖️ Balancing Complete: {df_pl.height} -> {reduced_count} rows.")
+
+        return self._process_return(df_balanced)
+
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        df_pl = self._convert_to_polars(df)
+        total_rows = df_pl.height
+
+        message = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Single Target: '{self.target}'"
+
+        stats = (
+            df_pl.group_by(self.target)
+            .len(name="Count")
+            .sort("Count", descending=True)
+            .with_columns(
+                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+            )
+        )
+
+        _LOGGER.info(f"{message}\n{stats.head(top_n)}")
```
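And the single-target counterpart, again on invented data:

```python
import polars as pl
from ml_tools.resampling import DragonResampler

df = pl.DataFrame({
    "x":     list(range(10)),
    "label": ["a"] * 7 + ["b"] * 3,
})

resampler = DragonResampler(target_column="label", return_pandas=True, seed=42)
resampler.describe_balance(df)

# Majority classes are capped at int(minority_count * majority_ratio) = int(3 * 1.5) = 4 rows
df_balanced = resampler.balance_classes(df, majority_ratio=1.5)
```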
The remaining files (WHEEL, licenses/LICENSE, licenses/LICENSE-THIRD-PARTY.md, top_level.txt) are unchanged between versions.