dragon-ml-toolbox 20.7.1__py3-none-any.whl → 20.9.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
dragon_ml_toolbox-20.9.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 20.7.1
+Version: 20.9.0
 Summary: Complete pipelines and helper tools for data science and machine learning projects.
 Author-email: Karl Luigi Loza Vidaurre <luigiloza@gmail.com>
 License-Expression: MIT
@@ -174,6 +174,7 @@ ML_vision_transformers
 optimization_tools
 path_manager
 plot_fonts
+resampling
 schema
 serde
 SQL
@@ -206,6 +207,7 @@ optimization_tools
 path_manager
 plot_fonts
 PSO_optimization
+resampling
 schema
 serde
 SQL
dragon_ml_toolbox-20.9.0.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
-dragon_ml_toolbox-20.7.1.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
-dragon_ml_toolbox-20.7.1.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
+dragon_ml_toolbox-20.9.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-20.9.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
 ml_tools/ETL_cleaning/__init__.py,sha256=gLRHF-qzwpqKTvbbn9chIQELeUDh_XGpBRX28j-5IqI,545
@@ -30,19 +30,20 @@ ml_tools/ML_chain/_update_schema.py,sha256=z1Us7lv6hy6GwSu1mcid50Jmqq3sh91hMQ0Ln
 ml_tools/ML_configuration/__init__.py,sha256=ogktFnYxz5jWJkhHS4DVaMldHkt3lT2gw9jx5PQ3d78,2755
 ml_tools/ML_configuration/_base_model_config.py,sha256=95L3IfobNFMtnNr79zYpDGerC1q1v7M05tWZvTS2cwE,2247
 ml_tools/ML_configuration/_finalize.py,sha256=l_n13bLu0avMdJ8hNRrH8V_wOBQZM1UGsTydKBkTysM,15047
-ml_tools/ML_configuration/_metrics.py,sha256=xKtEKzphtidwwU8UuUpGv4B8Y6Bv0tAOjEFUYfz8Ehc,23758
+ml_tools/ML_configuration/_metrics.py,sha256=KJM7HQeoEmJUUUrxNa4wYf2N9NawGPJoy7AGdNO3gxQ,24059
 ml_tools/ML_configuration/_models.py,sha256=lvuuqvD6DWUzOa3i06NZfrdfOi9bu2e26T_QO6BGMSw,7629
 ml_tools/ML_configuration/_training.py,sha256=_M_TwouHFNbGrZQtQNAvyG_poSVpmN99cbyUonZsHhk,8969
 ml_tools/ML_datasetmaster/__init__.py,sha256=UltQzuXnlXVCkD-aeA5TW4IcMVLnQf1_aglawg4WyrI,580
-ml_tools/ML_datasetmaster/_base_datasetmaster.py,sha256=lmqo9CN09xMu-YKYtKEnC2ZEzkxcZFJ0rS1B7K2-PKY,14691
+ml_tools/ML_datasetmaster/_base_datasetmaster.py,sha256=IgyVzRY3mlKDyBDklawvPF9SMjZFu8T2red6M-3MlQ4,16074
 ml_tools/ML_datasetmaster/_datasetmaster.py,sha256=Oy2UE3YJpKTaFwQF5TkQLgLB54-BFw_5b8wIPTxZIKU,19157
 ml_tools/ML_datasetmaster/_sequence_datasetmaster.py,sha256=cW3fuILZWs-7Yuo4T2fgGfTC4vwho3Gp4ohIKJYS7O0,18452
 ml_tools/ML_datasetmaster/_vision_datasetmaster.py,sha256=kvSqXYeNBN1JSRfSEEXYeIcsqy9HsJAl_EwFWClqlsw,67025
 ml_tools/ML_evaluation/__init__.py,sha256=e3c8JNP0tt4Kxc7QSQpGcOgrxf8JAucH4UkJvJxUL2E,1122
-ml_tools/ML_evaluation/_classification.py,sha256=8bKQejKrgMipnxU1T12ted7p60xvJS0d0MvHtdNBCBM,30971
+ml_tools/ML_evaluation/_classification.py,sha256=0URqIhNEgWedy-SYRmIJ2ejLKqatiuOU7qelJ6Cv3OE,33939
 ml_tools/ML_evaluation/_feature_importance.py,sha256=mTwi3LKom_axu6UFKunELj30APDdhG9GQC2w7I9mYhI,17137
+ml_tools/ML_evaluation/_helpers.py,sha256=kE1TSYIOAAcYI1EjdudyTfFeU47Wrl0E9eNL1EOwbKg,1217
 ml_tools/ML_evaluation/_loss.py,sha256=1a4O25i3Ya_3naNZNL7ELLUL46BY86g1scA7d7q2UFM,3625
-ml_tools/ML_evaluation/_regression.py,sha256=hnT2B2_6AnQ7aA7uk-X2lZL9G5JFGCduDXyZbr1gFCA,11037
+ml_tools/ML_evaluation/_regression.py,sha256=UZA7_fg85ZKJQWszioWDtmkplSiXeHJk2fBYR5bRXHY,11225
 ml_tools/ML_evaluation/_sequence.py,sha256=gUk9Uvmy7MrXkfrriMnfypkgJU5XERHdqekTa2gBaOM,8004
 ml_tools/ML_evaluation/_vision.py,sha256=abBHQ6Z2GunHNusL3wcLgfI1FVNA6hBUBTq1eOA8FSA,11489
 ml_tools/ML_evaluation_captum/_ML_evaluation_captum.py,sha256=6g3ymSxJGHXxwIN7WCD2Zi9zxKWEv-Qskd2cCGQQJ5Y,18439
@@ -103,10 +104,10 @@ ml_tools/_core/__init__.py,sha256=m-VP0RW0tOTm9N5NI3kFNcpM7WtVgs0RK9pK3ZJRZQQ,14
 ml_tools/_core/_logger.py,sha256=xzhn_FouMDRVNwXGBGlPC9Ruq6i5uCrmNaS5jesguMU,4972
 ml_tools/_core/_schema_load_ops.py,sha256=KLs9vBzANz5ESe2wlP-C41N4VlgGil-ywcfvWKSOGss,1551
 ml_tools/_core/_script_info.py,sha256=LtFGt10gEvCnhIRMKJPi2yXkiGLcdr7lE-oIP2XGHzQ,234
-ml_tools/data_exploration/__init__.py,sha256=nYKg1bPBgXibC5nhmNKPw3VaKFeVtlNGL_YpHixW-Pg,1795
-ml_tools/data_exploration/_analysis.py,sha256=H6LryV56FFCHWjvQdkhZbtprZy6aP8EqU_hC2Cf9CLE,7832
+ml_tools/data_exploration/__init__.py,sha256=efUBsruHL56B429tUadl3PdG73zAF639Y430uMQRfko,1917
+ml_tools/data_exploration/_analysis.py,sha256=PJNrEBz5ZZXHoUlQ6fh9Y86nzPQrLpVPv2Ye4NfOxgs,14181
 ml_tools/data_exploration/_cleaning.py,sha256=pAZOXgGK35j7O8q6cnyTwYK1GLNnD04A8p2fSyMB1mg,20906
-ml_tools/data_exploration/_features.py,sha256=wW-M8n2aLIy05DR2z4fI8wjpPjn3mOAnm9aSGYbMKwI,23363
+ml_tools/data_exploration/_features.py,sha256=Z1noJfDxBzFRfusFp6NlpLF2NItuZuzFHq4ssWFqny4,26273
 ml_tools/data_exploration/_plotting.py,sha256=zH1dPcIoAlOuww23xIoBCsQOAshPPv9OyGposOA2RvI,19883
 ml_tools/data_exploration/_schema_ops.py,sha256=Fd6fBGGv4OpxmJ1HG9pith6QL90z0tzssCvzkQxlEEQ,11083
 ml_tools/ensemble_evaluation/__init__.py,sha256=t4Gr8EGEk8RLatyc92-S0BzbQvdvodzoF-qDAH2qjVg,546
@@ -118,7 +119,7 @@ ml_tools/ensemble_learning/_ensemble_learning.py,sha256=MHDZBR20_nStlSSeThFI3bSu
 ml_tools/excel_handler/__init__.py,sha256=AaWM3n_dqBhJLTs3OEA57ex5YykKXNOwVCyHlVsdnqI,530
 ml_tools/excel_handler/_excel_handler.py,sha256=TODudmeQgDSdxUKzLfAzizs--VL-g8WxDOfQ4sgxxLs,13965
 ml_tools/keys/__init__.py,sha256=-0c2pmrhyfROc-oQpEjJGLBMhSagA3CyFijQaaqZRqU,399
-ml_tools/keys/_keys.py,sha256=lL9NlijxOEAhfDPPqK_wL3QhjalrYK_fWM-KNniSIOA,9308
+ml_tools/keys/_keys.py,sha256=56hlyPl2VUMsq7cFFLBypWHr-JU6ehWGwZG38l6IjI0,9389
 ml_tools/math_utilities/__init__.py,sha256=K7Obkkc4rPKj4EbRZf1BsXHfiCg7FXYv_aN9Yc2Z_Vg,400
 ml_tools/math_utilities/_math_utilities.py,sha256=BYHIVcM9tuKIhVrkgLLiM5QalJ39zx7dXYy_M9aGgiM,9012
 ml_tools/optimization_tools/__init__.py,sha256=KD8JXpfGuPndO4AHnjJGu6uV1GRwhOfboD0KZV45kzw,658
@@ -129,6 +130,10 @@ ml_tools/path_manager/_dragonmanager.py,sha256=q9wHTKPmdzywEz6N14ipUoeR3MmW0bzB4
 ml_tools/path_manager/_path_tools.py,sha256=LcZE31QlkzZWUR8g1MW_N_mPY2DpKBJLA45VJz7ZYsw,11905
 ml_tools/plot_fonts/__init__.py,sha256=KIxXRCjQ3SliEoLhEcqs7zDVZbVTn38bmSdL-yR1Q2w,187
 ml_tools/plot_fonts/_plot_fonts.py,sha256=mfjXNT9P59ymHoTI85Q8CcvfxfK5BIFBWtTZH-hNIC4,2209
+ml_tools/resampling/__init__.py,sha256=WB1YlNQgOIdSSQn-9eCIaiB0AHLSCkziFufqa-1QBG0,278
+ml_tools/resampling/_base_resampler.py,sha256=8IqkEJ7uiAiC9bqbKfsC-5vIvrN3EwH7lLVDlRKQzM8,1617
+ml_tools/resampling/_multi_resampling.py,sha256=m_iVvXPAu3p_EoBt2VZpgjhPLY1LmKa8fGtQo5E0pWk,7199
+ml_tools/resampling/_single_resampling.py,sha256=zKL4Br7Lm4Jq90X-ewQ6AKTsP923bq9RIMnTxIxtXBc,3896
 ml_tools/schema/__init__.py,sha256=K6uiZ9f0GCQ7etw1yl2-dQVLhU7RkL3KHesO3HNX6v4,334
 ml_tools/schema/_feature_schema.py,sha256=MuPf6Nf7tDhUTGyX7tcFHZh-lLSNsJkLmlf9IxdF4O4,9660
 ml_tools/schema/_gui_schema.py,sha256=IVwN4THAdFrvh2TpV4SFd_zlzMX3eioF-w-qcSVTndE,7245
@@ -138,7 +143,7 @@ ml_tools/utilities/__init__.py,sha256=h4lE3SQstg-opcQj6QSKhu-HkqSbmHExsWoM9vC5D9
 ml_tools/utilities/_translate.py,sha256=U8hRPa3PmTpIf9n9yR3gBGmp_hkcsjQLwjAHSHc0WHs,10325
 ml_tools/utilities/_utility_save_load.py,sha256=EFvFaTaHahDQWdJWZr-j7cHqRbG_Xrpc96228JhV-bs,16773
 ml_tools/utilities/_utility_tools.py,sha256=bN0J9d1S0W5wNzNntBWqDsJcEAK7-1OgQg3X2fwXns0,6918
-dragon_ml_toolbox-20.7.1.dist-info/METADATA,sha256=IB7aIajHgmlg0UvpBOjDfCiQWfNmM0G3NKSpiEvDlAs,7866
-dragon_ml_toolbox-20.7.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-20.7.1.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-20.7.1.dist-info/RECORD,,
+dragon_ml_toolbox-20.9.0.dist-info/METADATA,sha256=ehKhp6BpCkHcZnWpcoZU53rn4T0yI0Dboq3eH2vx8LU,7888
+dragon_ml_toolbox-20.9.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-20.9.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-20.9.0.dist-info/RECORD,,
ml_tools/ML_configuration/_metrics.py CHANGED
@@ -98,10 +98,11 @@ class _BaseMultiLabelFormat:
                  cmap: str = "BuGn",
                  ROC_PR_line: str='darkorange',
                  calibration_bins: Union[int, Literal['auto']]='auto',
-                 font_size: int = 25,
-                 xtick_size: int=20,
-                 ytick_size: int=20,
-                 legend_size: int=23) -> None:
+                 font_size: int = 26,
+                 xtick_size: int=22,
+                 ytick_size: int=22,
+                 legend_size: int=26,
+                 cm_font_size: int=26) -> None:
         """
         Initializes the formatting configuration for multi-label classification metrics.
 
@@ -127,6 +128,8 @@ class _BaseMultiLabelFormat:
 
             legend_size (int): Font size for plot legends.
 
+            cm_font_size (int): Font size for the confusion matrix.
+
             <br>
 
         ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
@@ -142,6 +145,7 @@ class _BaseMultiLabelFormat:
         self.xtick_size = xtick_size
         self.ytick_size = ytick_size
         self.legend_size = legend_size
+        self.cm_font_size = cm_font_size
 
     def __repr__(self) -> str:
         parts = [
@@ -151,7 +155,8 @@ class _BaseMultiLabelFormat:
             f"font_size={self.font_size}",
             f"xtick_size={self.xtick_size}",
             f"ytick_size={self.ytick_size}",
-            f"legend_size={self.legend_size}"
+            f"legend_size={self.legend_size}",
+            f"cm_font_size={self.cm_font_size}"
         ]
         return f"{self.__class__.__name__}({', '.join(parts)})"
 
@@ -520,10 +525,11 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
                  cmap: str = "BuGn",
                  ROC_PR_line: str='darkorange',
                  calibration_bins: Union[int, Literal['auto']]='auto',
-                 font_size: int = 25,
-                 xtick_size: int=20,
-                 ytick_size: int=20,
-                 legend_size: int=23
+                 font_size: int = 26,
+                 xtick_size: int=22,
+                 ytick_size: int=22,
+                 legend_size: int=26,
+                 cm_font_size: int=26
                  ) -> None:
         super().__init__(cmap=cmap,
                          ROC_PR_line=ROC_PR_line,
@@ -531,7 +537,8 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
                          font_size=font_size,
                          xtick_size=xtick_size,
                          ytick_size=ytick_size,
-                         legend_size=legend_size)
+                         legend_size=legend_size,
+                         cm_font_size=cm_font_size)
 
 
 # Segmentation
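Note: the net effect of the _metrics.py changes is that the default plot fonts grow slightly (font 25 → 26, ticks 20 → 22, legend 23 → 26) and a dedicated cm_font_size knob is added for confusion-matrix and report-heatmap text. A minimal usage sketch, assuming the class is re-exported from ml_tools.ML_configuration (an assumption based on the RECORD listing above):

    from ml_tools.ML_configuration import FormatMultiLabelBinaryClassificationMetrics

    # Unspecified arguments fall back to the new 20.9.0 defaults
    fmt = FormatMultiLabelBinaryClassificationMetrics(cm_font_size=30)
    print(fmt)  # __repr__ now includes cm_font_size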
ml_tools/ML_datasetmaster/_base_datasetmaster.py CHANGED
@@ -133,7 +133,7 @@ class _BaseDatasetMaker(ABC):
 
         # Get continuous feature indices *from the schema*
         if schema.continuous_feature_names:
-            if verbose >= 2:
+            if verbose >= 3:
                 _LOGGER.info("Getting continuous feature indices from schema.")
             try:
                 # Convert columns to a standard list for .index()
@@ -189,7 +189,7 @@ class _BaseDatasetMaker(ABC):
         # ------------------------------------------------------------------
 
         if self.target_scaler is None:
-            if verbose >= 2:
+            if verbose >= 3:
                 _LOGGER.info("Fitting a new DragonScaler on training targets.")
             # Convert to float tensor for calculation
             y_train_tensor = torch.tensor(y_train_arr, dtype=torch.float32)
@@ -202,6 +202,9 @@ class _BaseDatasetMaker(ABC):
             y_val_tensor = self.target_scaler.transform(torch.tensor(y_val_arr, dtype=torch.float32))
             y_test_tensor = self.target_scaler.transform(torch.tensor(y_test_arr, dtype=torch.float32))
             return y_train_tensor.numpy(), y_val_tensor.numpy(), y_test_tensor.numpy()
+
+        if verbose >= 2:
+            _LOGGER.info("Target scaling transformation complete.")
 
         return y_train_arr, y_val_arr, y_test_arr
 
@@ -214,6 +217,9 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def train_dataset(self) -> Dataset:
+        """
+        Returns the training dataset.
+        """
         if self._train_ds is None:
             _LOGGER.error("Train Dataset not yet created.")
             raise RuntimeError()
@@ -221,6 +227,9 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def validation_dataset(self) -> Dataset:
+        """
+        Returns the validation dataset.
+        """
         if self._val_ds is None:
             _LOGGER.error("Validation Dataset not yet created.")
             raise RuntimeError()
@@ -228,6 +237,9 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def test_dataset(self) -> Dataset:
+        """
+        Returns the test dataset.
+        """
         if self._test_ds is None:
             _LOGGER.error("Test Dataset not yet created.")
             raise RuntimeError()
@@ -235,30 +247,50 @@ class _BaseDatasetMaker(ABC):
 
     @property
     def feature_names(self) -> list[str]:
+        """
+        Returns a list with the feature names.
+        """
         return self._feature_names
 
     @property
     def target_names(self) -> list[str]:
+        """
+        Returns a list with the target names.
+        """
         return self._target_names
 
     @property
     def number_of_features(self) -> int:
+        """
+        Returns the number of features.
+        """
         return len(self._feature_names)
 
     @property
     def number_of_targets(self) -> int:
+        """
+        Returns the number of targets.
+        """
         return len(self._target_names)
 
     @property
     def id(self) -> Optional[str]:
+        """
+        Returns the dataset ID if set, otherwise None.
+        """
         return self._id
 
     @id.setter
     def id(self, dataset_id: str):
-        if not isinstance(dataset_id, str): raise ValueError("ID must be a string.")
+        if not isinstance(dataset_id, str):
+            _LOGGER.error("Dataset ID must be a string.")
+            raise ValueError()
         self._id = dataset_id
 
     def dataframes_info(self) -> None:
+        """
+        Prints the shapes of the dataframes after the split.
+        """
         print("--- DataFrame Shapes After Split ---")
         print(f"  X_train shape: {self._X_train_shape}, y_train shape: {self._y_train_shape}")
         print(f"  X_val shape: {self._X_val_shape}, y_val shape: {self._y_val_shape}")
@@ -266,12 +298,26 @@ class _BaseDatasetMaker(ABC):
         print("------------------------------------")
 
     def save_feature_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """
+        Saves the feature names to a text file.
+
+        Args:
+            directory (str | Path): Directory to save the feature names.
+            verbose (bool): Whether to print log messages.
+        """
         save_list_strings(list_strings=self._feature_names,
                           directory=directory,
                           filename=DatasetKeys.FEATURE_NAMES,
                           verbose=verbose)
 
     def save_target_names(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """
+        Saves the target names to a text file.
+
+        Args:
+            directory (str | Path): Directory to save the target names.
+            verbose (bool): Whether to print log messages.
+        """
         save_list_strings(list_strings=self._target_names,
                           directory=directory,
                           filename=DatasetKeys.TARGET_NAMES,
@@ -281,6 +327,10 @@ class _BaseDatasetMaker(ABC):
         """
         Saves both feature and target scalers (if they exist) to a single .pth file
         using a dictionary structure.
+
+        Args:
+            directory (str | Path): Directory to save the scaler.
+            verbose (bool): Whether to print log messages.
         """
         if self.feature_scaler is None and self.target_scaler is None:
             _LOGGER.warning("No scalers (feature or target) were fitted. Nothing to save.")
ml_tools/ML_evaluation/_classification.py CHANGED
@@ -28,6 +28,8 @@ from ..path_manager import make_fullpath, sanitize_filename
 from .._core import get_logger
 from ..keys._keys import _EvaluationConfig
 
+from ._helpers import check_and_abbreviate_name
+
 
 _LOGGER = get_logger("Classification Metrics")
 
@@ -85,7 +87,8 @@ def classification_metrics(save_dir: Union[str, Path],
     try:
         sorted_items = sorted(class_map.items(), key=lambda item: item[1])
         map_labels = [item[1] for item in sorted_items]
-        map_display_labels = [item[0] for item in sorted_items]
+        # Abbreviate display labels if needed
+        map_display_labels = [check_and_abbreviate_name(item[0]) for item in sorted_items]
     except Exception as e:
         _LOGGER.warning(f"Could not parse 'class_map': {e}")
         map_labels = None
@@ -397,6 +400,10 @@ def classification_metrics(save_dir: Union[str, Path],
         # --- Step 1: Get binned data directly ---
         # calculates reliability diagram data without needing a temporary plot
         prob_true, prob_pred = calibration_curve(y_true_binary, y_score, n_bins=dynamic_bins)
+
+        # Anchor the plot to (0,0) and (1,1) to ensure the line spans the full diagonal
+        prob_true = np.concatenate(([0.0], prob_true, [1.0]))
+        prob_pred = np.concatenate(([0.0], prob_pred, [1.0]))
 
         # --- Step 2: Plot ---
         ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
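Note: calibration_curve returns one point per non-empty bin, so the plotted reliability line could previously start and end mid-axis. Prepending (0, 0) and appending (1, 1) forces the curve to span the full diagonal. A standalone illustration of the same transform (numpy and scikit-learn only):

    import numpy as np
    from sklearn.calibration import calibration_curve

    y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
    y_score = np.array([0.2, 0.3, 0.7, 0.8, 0.6, 0.4, 0.9, 0.1])

    prob_true, prob_pred = calibration_curve(y_true, y_score, n_bins=4)
    # Anchor both ends, mirroring the change above
    prob_true = np.concatenate(([0.0], prob_true, [1.0]))
    prob_pred = np.concatenate(([0.0], prob_pred, [1.0]))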
@@ -467,6 +474,9 @@ def multi_label_classification_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
 
+    # --- Pre-process target names for abbreviation ---
+    target_names = [check_and_abbreviate_name(name) for name in target_names]
+
     # --- Parse Config or use defaults ---
     if config is None:
         # Create a default config if one wasn't provided
@@ -481,6 +491,10 @@ def multi_label_classification_metrics(
     ytick_size = format_config.ytick_size
     legend_size = format_config.legend_size
     base_font_size = format_config.font_size
+
+    # config font size for heatmap
+    cm_font_size = format_config.cm_font_size
+    cm_tick_size = cm_font_size - 4
 
     # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
@@ -488,7 +502,7 @@ def multi_label_classification_metrics(
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics:\n" # No threshold to report here
+        f"Overall Multi-Label Metrics:\n"
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -499,9 +513,65 @@ def multi_label_classification_metrics(
     overall_report_path = save_dir_path / "classification_report.txt"
     overall_report_path.write_text(overall_report)
 
+    # --- Save Classification Report Heatmap (Multi-label) ---
+    try:
+        # Generate full report as dict
+        full_report_dict = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
+        report_df = pd.DataFrame(full_report_dict)
+
+        # Cleanup
+        # Remove 'accuracy' column if it exists
+        report_df = report_df.drop(columns=['accuracy'], errors='ignore')
+
+        # Remove 'support' row explicitly
+        if 'support' in report_df.index:
+            report_df = report_df.drop(index='support')
+
+        # Transpose: Rows = Classes/Averages, Cols = Metrics
+        plot_df = report_df.T
+
+        # Dynamic Height
+        fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
+        fig_width = 8.0
+
+        fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
+
+        # Plot
+        sns.heatmap(plot_df,
+                    annot=True,
+                    cmap=format_config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0,
+                    cbar_kws={'shrink': 0.9})
+
+        ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+
+        # manually increase the font size of the elements
+        for text in ax_heat.texts:
+            text.set_fontsize(cm_tick_size)
+
+        cbar = ax_heat.collections[0].colorbar
+        cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore
+
+        ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
+        ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0)
+
+        plt.tight_layout()
+        heatmap_path = save_dir_path / "classification_report_heatmap.svg"
+        plt.savefig(heatmap_path)
+        _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
+        plt.close(fig_heat)
+
+    except Exception as e:
+        _LOGGER.error(f"Could not generate multi-label classification report heatmap: {e}")
+
     # --- Per-Label Metrics and Plots ---
     for i, name in enumerate(target_names):
-        print(f" -> Evaluating label: '{name}'")
+        # strip whitespace from name
+        name = name.strip()
+
+        # print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
         pred_i = y_pred[:, i] # Use passed-in y_pred
         prob_i = y_prob[:, i] # Use passed-in y_prob
@@ -537,7 +607,7 @@ def multi_label_classification_metrics(
         ax_cm.tick_params(axis='y', labelsize=ytick_size)
 
         # Set titles and labels with padding
-        ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_cm.set_title(f"Confusion Matrix - {name}", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -594,7 +664,7 @@ def multi_label_classification_metrics(
         ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
 
-        ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_roc.set_title(f'ROC Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -616,7 +686,7 @@ def multi_label_classification_metrics(
         ap_score = average_precision_score(true_i, prob_i)
         fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
         ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
-        ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_pr.set_title(f'PR Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -654,15 +724,20 @@ def multi_label_classification_metrics(
         # Calculate calibration curve for this specific label
         prob_true, prob_pred = calibration_curve(true_i, prob_i, n_bins=dynamic_bins)
 
+        # Anchor the plot to (0,0) and (1,1)
+        prob_true = np.concatenate(([0.0], prob_true, [1.0]))
+        prob_pred = np.concatenate(([0.0], prob_pred, [1.0]))
+
+        # Plot the calibration curve
         ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
         ax_cal.plot(prob_pred,
                     prob_true,
                     marker='o',
                     linewidth=2,
-                    label=f"Calibration for '{name}'",
+                    label=f"Model Calibration",
                     color=format_config.ROC_PR_line)
 
-        ax_cal.set_title(f'Reliability Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_cal.set_title(f'Calibration - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
ml_tools/ML_evaluation/_helpers.py ADDED
@@ -0,0 +1,41 @@
+from ..keys._keys import _EvaluationConfig
+from ..path_manager import sanitize_filename
+from .._core import get_logger
+
+
+_LOGGER = get_logger("Metrics Helper")
+
+
+def check_and_abbreviate_name(name: str) -> str:
+    """
+    Checks if a name exceeds the NAME_LIMIT. If it does, creates an abbreviation
+    (initials of words) or truncates it if the abbreviation is empty.
+
+    Args:
+        name (str): The original label or target name.
+
+    Returns:
+        str: The potentially abbreviated name.
+    """
+    limit = _EvaluationConfig.NAME_LIMIT
+
+    # Strip whitespace
+    name = name.strip()
+
+    if len(name) <= limit:
+        return name
+
+    # Attempt abbreviation: First letter of each word (split by space or underscore)
+    parts = [w for w in name.replace("_", " ").split() if w]
+    abbr = "".join(p[0].upper() for p in parts)
+
+    # Keep only alphanumeric characters
+    abbr = "".join(ch for ch in abbr if ch.isalnum())
+
+    # Fallback if abbreviation failed or is empty
+    if not abbr:
+        sanitized = sanitize_filename(name)
+        abbr = sanitized[:limit]
+
+    _LOGGER.warning(f"Label '{name}' is too long. Abbreviating to '{abbr}'.")
+    return abbr
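Note: with NAME_LIMIT = 15 (added to _keys.py below), names of 15 characters or fewer pass through unchanged and longer ones collapse to word initials. A self-contained re-implementation for illustration only (the packaged helper reads the limit from _EvaluationConfig and logs a warning):

    def abbreviate(name: str, limit: int = 15) -> str:
        name = name.strip()
        if len(name) <= limit:
            return name
        # First letter of each word, split on spaces and underscores
        parts = [w for w in name.replace("_", " ").split() if w]
        abbr = "".join(p[0].upper() for p in parts)
        return abbr or name[:limit]

    abbreviate("age")                           # 'age' (within the limit)
    abbreviate("chronic_kidney_disease_stage")  # 'CKDS'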
ml_tools/ML_evaluation/_regression.py CHANGED
@@ -19,6 +19,8 @@ from ..path_manager import make_fullpath, sanitize_filename
 from .._core import get_logger
 from ..keys._keys import _EvaluationConfig
 
+from ._helpers import check_and_abbreviate_name
+
 
 _LOGGER = get_logger("Regression Metrics")
 
@@ -180,6 +182,9 @@ def multi_target_regression_metrics(
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
         raise ValueError()
+
+    # --- Pre-process target names for abbreviation ---
+    target_names = [check_and_abbreviate_name(name) for name in target_names]
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
ml_tools/data_exploration/__init__.py CHANGED
@@ -2,6 +2,7 @@ from ._analysis import (
     summarize_dataframe,
     show_null_columns,
     match_and_filter_columns_by_regex,
+    check_class_balance,
 )
 
 from ._cleaning import (
@@ -28,6 +29,7 @@ from ._features import (
     split_continuous_binary,
     split_continuous_categorical_targets,
     encode_categorical_features,
+    encode_classification_target,
     reconstruct_one_hot,
     reconstruct_binary,
     reconstruct_multibinary,
@@ -44,7 +46,6 @@ from .._core import _imprimir_disponibles
 
 __all__ = [
     "summarize_dataframe",
-    "show_null_columns",
     "drop_constant_columns",
     "drop_rows_with_missing_data",
     "drop_columns_with_missing_data",
@@ -61,10 +62,13 @@ __all__ = [
     "plot_categorical_vs_target",
     "plot_correlation_heatmap",
     "encode_categorical_features",
+    "encode_classification_target",
     "finalize_feature_schema",
     "apply_feature_schema",
     "reconstruct_from_schema",
     "match_and_filter_columns_by_regex",
+    "show_null_columns",
+    "check_class_balance",
     "standardize_percentages",
     "reconstruct_one_hot",
     "reconstruct_binary",
ml_tools/data_exploration/_analysis.py CHANGED
@@ -16,6 +16,7 @@ __all__ = [
     "summarize_dataframe",
     "show_null_columns",
     "match_and_filter_columns_by_regex",
+    "check_class_balance",
 ]
 
 
@@ -212,3 +213,151 @@ def match_and_filter_columns_by_regex(
 
     return filtered_df, matched_columns
 
+
+def check_class_balance(
+    df: pd.DataFrame,
+    target: Union[str, list[str]],
+    plot_to_dir: Optional[Union[str, Path]] = None,
+    plot_filename: str = "Class_Balance"
+) -> pd.DataFrame:
+    """
+    Analyzes the class balance for classification targets.
+
+    Handles two cases:
+    1. Single Column (Binary/Multi-class): Calculates frequency of each unique value.
+    2. List of Columns (Multi-label Binary): Calculates the frequency of positive values (1) per column.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        target (str | list[str]): The target column name (for single/multi-class classification)
+            or list of column names (for multi-label-binary classification).
+        plot_to_dir (str | Path | None): Directory to save the balance plot.
+        plot_filename (str): Filename for the plot (without extension).
+
+    Returns:
+        pd.DataFrame: Summary table of counts and percentages.
+    """
+    # Early fail for empty DataFrame and handle list of targets with only one item
+    if df.empty:
+        _LOGGER.error("Input DataFrame is empty.")
+        raise ValueError()
+
+    if isinstance(target, list):
+        if len(target) == 0:
+            _LOGGER.error("Target list is empty.")
+            raise ValueError()
+        elif len(target) == 1:
+            target = target[0] # Simplify to single column case
+
+    # Case 1: Single Target (Binary or Multi-class)
+    if isinstance(target, str):
+        if target not in df.columns:
+            _LOGGER.error(f"Target column '{target}' not found in DataFrame.")
+            raise ValueError()
+
+        # Calculate stats
+        counts = df[target].value_counts(dropna=False).sort_index()
+        percents = df[target].value_counts(normalize=True, dropna=False).sort_index() * 100
+
+        summary = pd.DataFrame({
+            'Count': counts,
+            'Percentage': percents.round(2)
+        })
+        summary.index.name = "Class"
+
+        # Plotting
+        if plot_to_dir:
+            try:
+                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")
+
+                plt.figure(figsize=(10, 6))
+                # Convert index to str to handle numeric classes cleanly on x-axis
+                x_labels = summary.index.astype(str)
+                bars = plt.bar(x_labels, summary['Count'], color='lightgreen', edgecolor='black', alpha=0.7)
+
+                plt.title(f"Class Balance: {target}")
+                plt.xlabel(target)
+                plt.ylabel("Count")
+                plt.grid(axis='y', linestyle='--', alpha=0.5)
+
+                # Add percentage labels on top of bars
+                for bar, pct in zip(bars, summary['Percentage']):
+                    height = bar.get_height()
+                    plt.text(bar.get_x() + bar.get_width()/2, height,
+                             f'{pct:.1f}%', ha='center', va='bottom', fontsize=10)
+
+                plt.tight_layout()
+                full_filename = sanitize_filename(plot_filename) + ".svg"
+                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
+                plt.close()
+                _LOGGER.info(f"Saved class balance plot: '{full_filename}'")
+            except Exception as e:
+                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
+                plt.close()
+
+        return summary
+
+    # Case 2: Multi-label (List of binary columns)
+    elif isinstance(target, list):
+        missing_cols = [t for t in target if t not in df.columns]
+        if missing_cols:
+            _LOGGER.error(f"Target columns not found: {missing_cols}")
+            raise ValueError()
+
+        stats = []
+        for col in target:
+            # Assume 0/1 or False/True. Sum gives the count of positives.
+            # We enforce numeric to be safe
+            try:
+                numeric_series = pd.to_numeric(df[col], errors='coerce').fillna(0)
+                pos_count = numeric_series.sum()
+                total_count = len(df)
+                pct = (pos_count / total_count) * 100
+            except Exception:
+                _LOGGER.warning(f"Column '{col}' could not be processed as numeric. Assuming 0 positives.")
+                pos_count = 0
+                pct = 0.0
+
+            stats.append({
+                'Label': col,
+                'Positive_Count': int(pos_count),
+                'Positive_Percentage': round(pct, 2)
+            })
+
+        summary = pd.DataFrame(stats).set_index("Label").sort_values("Positive_Percentage", ascending=True)
+
+        # Plotting
+        if plot_to_dir:
+            try:
+                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")
+
+                # Dynamic height for many labels
+                height = max(6, len(target) * 0.4)
+                plt.figure(figsize=(10, height))
+
+                bars = plt.barh(summary.index, summary['Positive_Percentage'], color='lightgreen', edgecolor='black', alpha=0.7)
+
+                plt.title(f"Multi-label Binary Class Balance")
+                plt.xlabel("Positive Class Percentage (%)")
+                plt.xlim(0, 100)
+                plt.grid(axis='x', linestyle='--', alpha=0.5)
+
+                # Add count labels at the end of bars
+                for bar, count in zip(bars, summary['Positive_Count']):
+                    width = bar.get_width()
+                    plt.text(width + 1, bar.get_y() + bar.get_height()/2, f'{width:.1f}%', ha='left', va='center', fontsize=9)
+
+                plt.tight_layout()
+                full_filename = sanitize_filename(plot_filename) + ".svg"
+                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
+                plt.close()
+                _LOGGER.info(f"Saved multi-label balance plot: '{full_filename}'")
+            except Exception as e:
+                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
+                plt.close()
+
+        return summary.sort_values("Positive_Percentage", ascending=False)
+
+    else:
+        _LOGGER.error("Target must be a string or a list of strings.")
+        raise TypeError()
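Note: a usage sketch for the new check_class_balance, assuming the public re-export from ml_tools.data_exploration that the __init__.py hunks above add:

    import pandas as pd
    from ml_tools.data_exploration import check_class_balance

    df = pd.DataFrame({
        "disease": ["A", "A", "B", "A", "C"],
        "tag_x": [1, 0, 1, 0, 0],
        "tag_y": [0, 0, 1, 1, 1],
    })

    # Single target: per-class Count / Percentage table, indexed by 'Class'
    summary = check_class_balance(df, target="disease")

    # Multi-label: positive count and percentage per binary column,
    # optionally saving an SVG bar plot to the given directory
    ml_summary = check_class_balance(df, target=["tag_x", "tag_y"], plot_to_dir="plots")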
ml_tools/data_exploration/_features.py CHANGED
@@ -3,7 +3,10 @@ from pandas.api.types import is_numeric_dtype, is_object_dtype
 import numpy as np
 from typing import Any, Optional, Union
 import re
+import json
+from pathlib import Path
 
+from ..path_manager import make_fullpath
 from .._core import get_logger
 
 
@@ -15,6 +18,7 @@ __all__ = [
     "split_continuous_binary",
     "split_continuous_categorical_targets",
    "encode_categorical_features",
+    "encode_classification_target",
     "reconstruct_one_hot",
     "reconstruct_binary",
     "reconstruct_multibinary",
@@ -263,6 +267,78 @@ def encode_categorical_features(
     return df_encoded, mappings
 
 
+def encode_classification_target(
+    df: pd.DataFrame,
+    target_col: str,
+    save_dir: Union[str, Path],
+    verbose: int = 2
+) -> tuple[pd.DataFrame, dict[str, int]]:
+    """
+    Encodes a target classification column into integers (0, 1, 2...) and saves the mapping to a JSON file.
+
+    This ensures that the target variable is in the correct numeric format for training
+    and provides a persistent artifact (the JSON file) to map predictions back to labels later.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame.
+        target_col (str): Name of the target column to encode.
+        save_dir (str | Path): Directory where the class map JSON will be saved.
+        verbose (int): Verbosity level for logging.
+
+    Returns:
+        Tuple (Dataframe, Dict):
+        - A new DataFrame with the target column encoded as integers.
+        - The dictionary mapping original labels (str) to integers (int).
+    """
+    if target_col not in df.columns:
+        _LOGGER.error(f"Target column '{target_col}' not found in DataFrame.")
+        raise ValueError()
+
+    # Validation: Check for missing values in target
+    if df[target_col].isnull().any():
+        n_missing = df[target_col].isnull().sum()
+        _LOGGER.error(f"Target column '{target_col}' contains {n_missing} missing values. Please handle them before encoding.")
+        raise ValueError()
+
+    # Ensure directory exists
+    save_path = make_fullpath(save_dir, make=True, enforce="directory")
+    file_path = save_path / "class_map.json"
+
+    # Get unique values and sort them to ensure deterministic encoding (0, 1, 2...)
+    # Convert to string to ensure the keys in JSON are strings
+    unique_labels = sorted(df[target_col].astype(str).unique())
+
+    # Create mapping: { Label -> Integer }
+    class_map = {label: idx for idx, label in enumerate(unique_labels)}
+
+    # Apply mapping
+    # cast column to string to match the keys in class_map
+    df_encoded = df.copy()
+    df_encoded[target_col] = df_encoded[target_col].astype(str).map(class_map)
+
+    # Save to JSON
+    try:
+        with open(file_path, 'w', encoding='utf-8') as f:
+            json.dump(class_map, f, indent=4)
+
+        if verbose >= 2:
+            _LOGGER.info(f"Class mapping saved to: '{file_path}'")
+
+        if verbose >= 3:
+            _LOGGER.info(f"Target '{target_col}' encoded with {len(class_map)} classes.")
+            # Print a preview
+            if len(class_map) <= 10:
+                print(f"  Mapping: {class_map}")
+            else:
+                print(f"  Mapping (first 5): {dict(list(class_map.items())[:5])} ...")
+
+    except Exception as e:
+        _LOGGER.error(f"Failed to save class map JSON. Error: {e}")
+        raise IOError()
+
+    return df_encoded, class_map
+
+
 def reconstruct_one_hot(
     df: pd.DataFrame,
     features_to_reconstruct: list[Union[str, tuple[str, Optional[str]]]],
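Note: a usage sketch for encode_classification_target (import path assumed from the __init__.py hunk above; 'artifacts' is a placeholder directory):

    import pandas as pd
    from ml_tools.data_exploration import encode_classification_target

    df = pd.DataFrame({"species": ["cat", "dog", "bird", "dog"]})
    df_enc, class_map = encode_classification_target(df, target_col="species", save_dir="artifacts")

    class_map           # {'bird': 0, 'cat': 1, 'dog': 2} (labels sorted, then enumerated)
    df_enc["species"]   # 1, 2, 0, 2
    # 'artifacts/class_map.json' now persists the mapping for decoding predictions later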
ml_tools/keys/_keys.py CHANGED
@@ -306,6 +306,7 @@ class _EvaluationConfig:
     LOSS_PLOT_LEGEND_SIZE = 24
     # CM settings
     CM_SIZE = (9, 8) # used for multi label binary classification confusion matrix
+    NAME_LIMIT = 15 # max number of characters for feature/label names in plots
 
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
ml_tools/resampling/__init__.py ADDED
@@ -0,0 +1,19 @@
+from ._single_resampling import (
+    DragonResampler,
+)
+
+from ._multi_resampling import (
+    DragonMultiResampler,
+)
+
+from .._core import _imprimir_disponibles
+
+
+__all__ = [
+    "DragonResampler",
+    "DragonMultiResampler",
+]
+
+
+def info():
+    _imprimir_disponibles(__all__)
ml_tools/resampling/_base_resampler.py ADDED
@@ -0,0 +1,49 @@
+import polars as pl
+import pandas as pd
+from typing import Union
+from abc import ABC, abstractmethod
+
+
+__all__ = ["_DragonBaseResampler"]
+
+
+class _DragonBaseResampler(ABC):
+    """
+    Base class for Dragon resamplers handling common I/O and state.
+    """
+    def __init__(self,
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        self.return_pandas = return_pandas
+        self.seed = seed
+
+    def _convert_to_polars(self, df: Union[pd.DataFrame, pl.DataFrame]) -> pl.DataFrame:
+        """Standardizes input to Polars DataFrame."""
+        if isinstance(df, pd.DataFrame):
+            return pl.from_pandas(df)
+        return df
+
+    def _convert_to_pandas(self, df: pl.DataFrame) -> pd.DataFrame:
+        """Converts Polars DataFrame back to Pandas."""
+        return df.to_pandas(use_pyarrow_extension_array=False)
+
+    def _process_return(self, df: pl.DataFrame, shuffle: bool = True) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Finalizes the DataFrame:
+        1. Global Shuffle (optional but recommended for ML).
+        2. Conversion to Pandas (if requested).
+        """
+        if shuffle:
+            # Random shuffle of the final dataset
+            df = df.sample(fraction=1.0, seed=self.seed, with_replacement=False)
+
+        if self.return_pandas:
+            return self._convert_to_pandas(df)
+        return df
+
+    @abstractmethod
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        """
+        Prints a statistical summary of the target distribution.
+        """
+        pass
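Note: both concrete resamplers below cap group sizes with the same Polars idiom: assign each row a random within-group rank with pl.int_range(...).shuffle().over(key), then keep only ranks below the cap. A minimal standalone demonstration:

    import polars as pl

    df = pl.DataFrame({"g": ["a"] * 5 + ["b"] * 2, "v": list(range(7))})
    capped = df.filter(
        pl.int_range(0, pl.len()).shuffle(seed=42).over("g") < 3  # keep at most 3 random rows per group
    )
    # group 'a' shrinks from 5 rows to 3; group 'b' keeps both rows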
ml_tools/resampling/_multi_resampling.py ADDED
@@ -0,0 +1,184 @@
+import polars as pl
+import pandas as pd
+import numpy as np
+from typing import Union, Optional
+
+from .._core import get_logger
+
+from ._base_resampler import _DragonBaseResampler
+
+
+_LOGGER = get_logger("DragonMultiResampler")
+
+
+__all__ = [
+    "DragonMultiResampler",
+]
+
+
+class DragonMultiResampler(_DragonBaseResampler):
+    """
+    A robust resampler for multi-label binary classification tasks using Polars.
+
+    It provides methods to downsample "all-negative" rows and balance the dataset
+    based on unique label combinations (Powerset).
+    """
+    def __init__(self,
+                 target_columns: list[str],
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        """
+        Args:
+            target_columns (List[str]): The list of binary target column names.
+            return_pandas (bool): Whether to return results as pandas DataFrame.
+            seed (int): Random seed for reproducibility.
+        """
+        super().__init__(return_pandas=return_pandas, seed=seed)
+        self.targets = target_columns
+
+    def downsample_all_negatives(self,
+                                 df: Union[pd.DataFrame, pl.DataFrame],
+                                 negative_ratio: float = 1.0,
+                                 verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Downsamples rows where ALL target columns are 0 ("background" class).
+
+        Args:
+            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
+            negative_ratio (float): Ratio of negatives to positives to retain.
+            verbose (int): Verbosity level for logging.
+
+        Returns:
+            Dataframe: Resampled DataFrame.
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Identify "All Negative" vs "Has Signal"
+        fold_expr = pl.sum_horizontal(pl.col(self.targets)).cast(pl.UInt32)
+
+        df_pos = df_pl.filter(fold_expr > 0)
+        df_neg = df_pl.filter(fold_expr == 0)
+
+        n_pos = df_pos.height
+        n_neg_original = df_neg.height
+
+        if n_pos == 0:
+            if verbose >= 1:
+                _LOGGER.warning("No positive cases found in any label. Returning original DataFrame.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 2. Calculate target count for negatives
+        target_n_neg = int(n_pos * negative_ratio)
+
+        # 3. Sample if necessary
+        if n_neg_original > target_n_neg:
+            if verbose >= 2:
+                _LOGGER.info(f"📉 Downsampling 'All-Negative' rows from {n_neg_original} to {target_n_neg}")
+
+            # Here we use standard sampling because we are not grouping
+            df_neg_sampled = df_neg.sample(n=target_n_neg, seed=self.seed, with_replacement=False)
+            df_resampled = pl.concat([df_pos, df_neg_sampled])
+
+            return self._process_return(df_resampled)
+        else:
+            if verbose >= 1:
+                _LOGGER.warning(f"Negative count ({n_neg_original}) is already below target ({target_n_neg}). No downsampling applied.")
+            return self._process_return(df_pl, shuffle=False)
+
+    def balance_powerset(self,
+                         df: Union[pd.DataFrame, pl.DataFrame],
+                         max_samples_per_combination: Optional[int] = None,
+                         quantile_limit: float = 0.90,
+                         verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Groups data by unique label combinations (Powerset) and downsamples
+        majority combinations.
+
+        Args:
+            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
+            max_samples_per_combination (int | None): Fixed cap per combination.
+                If None, uses quantile_limit to determine cap.
+            quantile_limit (float): Quantile to determine cap if max_samples_per_combination is None.
+            verbose (int): Verbosity level for logging.
+
+        Returns:
+            Dataframe: Resampled DataFrame.
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Create a hash/structural representation of the targets for grouping
+        df_lazy = df_pl.lazy().with_columns(
+            pl.concat_list(pl.col(self.targets)).alias("_powerset_key")
+        )
+
+        # 2. Calculate frequencies
+        # We need to collect partially to calculate the quantile cap
+        combo_counts = df_lazy.group_by("_powerset_key").len().collect()
+
+        # Determine the Cap
+        if max_samples_per_combination is None:
+            # Handle potential None from quantile (satisfies linter)
+            q_val = combo_counts["len"].quantile(quantile_limit)
+
+            if q_val is None:
+                if verbose >= 1:
+                    _LOGGER.warning("Data empty or insufficient to calculate quantile. Returning original.")
+                return self._process_return(df_pl, shuffle=False)
+
+            cap_size = int(q_val)
+
+            if verbose >= 3:
+                _LOGGER.info(f"📊 Auto-calculated Powerset Cap: {cap_size} samples (based on {quantile_limit} quantile).")
+        else:
+            cap_size = max_samples_per_combination
+
+        # 3. Apply Stratified Sampling / Capping (Randomized)
+        df_balanced = (
+            df_lazy
+            .filter(
+                pl.int_range(0, pl.len())
+                .shuffle(seed=self.seed)
+                .over("_powerset_key")
+                < cap_size
+            )
+            .drop("_powerset_key")
+            .collect()
+        )
+
+        if verbose >= 2:
+            original_count = df_pl.height
+            new_count = df_balanced.height
+            _LOGGER.info(f"⚖️ Powerset Balancing: Reduced from {original_count} to {new_count} rows.")
+
+        return self._process_return(df_balanced)
+
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        df_pl = self._convert_to_polars(df)
+        total_rows = df_pl.height
+
+        message_1 = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Multi-Targets: {len(self.targets)} columns"
+
+        # A. Individual Label Counts
+        sums = df_pl.select([
+            pl.sum(col).alias(col) for col in self.targets
+        ]).transpose(include_header=True, header_name="Label", column_names=["Count"])
+
+        sums = sums.with_columns(
+            (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+        ).sort("Count", descending=True)
+
+        message_1 += "\n🔹 Individual Label Frequencies:"
+
+        # B. Powerset (Combination) Counts
+        message_2 = f"🔹 Top {top_n} Label Combinations (Powerset):"
+
+        combo_stats = (
+            df_pl.group_by(self.targets)
+            .len(name="Count")
+            .sort("Count", descending=True)
+            .with_columns(
+                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+            )
+        )
+
+        _LOGGER.info(f"{message_1}\n{sums.head(top_n)}\n{message_2}\n{combo_stats.head(top_n)}")
ml_tools/resampling/_single_resampling.py ADDED
@@ -0,0 +1,113 @@
+import polars as pl
+import pandas as pd
+import numpy as np
+from typing import Union
+
+from .._core import get_logger
+
+from ._base_resampler import _DragonBaseResampler
+
+
+_LOGGER = get_logger("DragonResampler")
+
+
+__all__ = [
+    "DragonResampler",
+]
+
+
+class DragonResampler(_DragonBaseResampler):
+    """
+    A resampler for Single-Target Classification tasks (Binary or Multiclass).
+
+    It balances classes by downsampling majority classes relative to the size of the minority class.
+    """
+    def __init__(self,
+                 target_column: str,
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        """
+        Args:
+            target_column (str): The name of the single target column.
+            return_pandas (bool): Whether to return results as pandas DataFrame.
+            seed (int): Random seed for reproducibility.
+        """
+        super().__init__(return_pandas=return_pandas, seed=seed)
+        self.target = target_column
+
+    def balance_classes(self,
+                        df: Union[pd.DataFrame, pl.DataFrame],
+                        majority_ratio: float = 1.0,
+                        verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Downsamples all classes to match the minority class size (scaled by a ratio).
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Calculate Class Counts
+        counts = df_pl.group_by(self.target).len().sort("len")
+
+        if counts.height == 0:
+            _LOGGER.error("DataFrame is empty or target column missing.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 2. Identify Statistics
+        min_val = counts["len"].min()
+        max_val = counts["len"].max()
+
+        if min_val is None or max_val is None:
+            _LOGGER.error("Failed to calculate class statistics (unexpected None).")
+            raise ValueError()
+
+        minority_count: int = min_val # type: ignore
+        majority_count: int = max_val # type: ignore
+
+        # Calculate the cap
+        cap_size = int(minority_count * majority_ratio)
+
+        if verbose >= 3:
+            _LOGGER.info(f"📊 Class Distribution:\n{counts}")
+            _LOGGER.info(f"🎯 Strategy: Cap majorities at {cap_size}")
+
+        # Optimization: If data is already balanced enough
+        if majority_count <= cap_size:
+            if verbose >= 2:
+                _LOGGER.info("Data is already within the requested balance ratio.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 3. Apply Downsampling (Randomized)
+        # We generate a random range index per group and filter by it.
+        # This ensures we pick a random subset, not the first N rows.
+        df_balanced = (
+            df_pl.lazy()
+            .filter(
+                pl.int_range(0, pl.len())
+                .shuffle(seed=self.seed)
+                .over(self.target)
+                < cap_size
+            )
+            .collect()
+        )
+
+        if verbose >= 2:
+            reduced_count = df_balanced.height
+            _LOGGER.info(f"⚖️ Balancing Complete: {df_pl.height} -> {reduced_count} rows.")
+
+        return self._process_return(df_balanced)
+
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        df_pl = self._convert_to_polars(df)
+        total_rows = df_pl.height
+
+        message = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Single Target: '{self.target}'"
+
+        stats = (
+            df_pl.group_by(self.target)
+            .len(name="Count")
+            .sort("Count", descending=True)
+            .with_columns(
+                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+            )
+        )
+
+        _LOGGER.info(f"{message}\n{stats.head(top_n)}")