dragon-ml-toolbox 20.7.1__py3-none-any.whl → 20.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dragon_ml_toolbox-{20.7.1 → 20.8.0}.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 20.7.1
+Version: 20.8.0
 Summary: Complete pipelines and helper tools for data science and machine learning projects.
 Author-email: Karl Luigi Loza Vidaurre <luigiloza@gmail.com>
 License-Expression: MIT
@@ -174,6 +174,7 @@ ML_vision_transformers
 optimization_tools
 path_manager
 plot_fonts
+resampling
 schema
 serde
 SQL
@@ -206,6 +207,7 @@ optimization_tools
 path_manager
 plot_fonts
 PSO_optimization
+resampling
 schema
 serde
 SQL
dragon_ml_toolbox-{20.7.1 → 20.8.0}.dist-info/RECORD RENAMED
@@ -1,5 +1,5 @@
-dragon_ml_toolbox-20.7.1.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
-dragon_ml_toolbox-20.7.1.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
+dragon_ml_toolbox-20.8.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-20.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=0-HBRMMgKuwtGy6nMJZvIn1fLxhx_ksyyVB2U_iyYZU,2818
 ml_tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
 ml_tools/ETL_cleaning/__init__.py,sha256=gLRHF-qzwpqKTvbbn9chIQELeUDh_XGpBRX28j-5IqI,545
@@ -30,7 +30,7 @@ ml_tools/ML_chain/_update_schema.py,sha256=z1Us7lv6hy6GwSu1mcid50Jmqq3sh91hMQ0Ln
 ml_tools/ML_configuration/__init__.py,sha256=ogktFnYxz5jWJkhHS4DVaMldHkt3lT2gw9jx5PQ3d78,2755
 ml_tools/ML_configuration/_base_model_config.py,sha256=95L3IfobNFMtnNr79zYpDGerC1q1v7M05tWZvTS2cwE,2247
 ml_tools/ML_configuration/_finalize.py,sha256=l_n13bLu0avMdJ8hNRrH8V_wOBQZM1UGsTydKBkTysM,15047
-ml_tools/ML_configuration/_metrics.py,sha256=xKtEKzphtidwwU8UuUpGv4B8Y6Bv0tAOjEFUYfz8Ehc,23758
+ml_tools/ML_configuration/_metrics.py,sha256=KJM7HQeoEmJUUUrxNa4wYf2N9NawGPJoy7AGdNO3gxQ,24059
 ml_tools/ML_configuration/_models.py,sha256=lvuuqvD6DWUzOa3i06NZfrdfOi9bu2e26T_QO6BGMSw,7629
 ml_tools/ML_configuration/_training.py,sha256=_M_TwouHFNbGrZQtQNAvyG_poSVpmN99cbyUonZsHhk,8969
 ml_tools/ML_datasetmaster/__init__.py,sha256=UltQzuXnlXVCkD-aeA5TW4IcMVLnQf1_aglawg4WyrI,580
@@ -39,7 +39,7 @@ ml_tools/ML_datasetmaster/_datasetmaster.py,sha256=Oy2UE3YJpKTaFwQF5TkQLgLB54-BF
 ml_tools/ML_datasetmaster/_sequence_datasetmaster.py,sha256=cW3fuILZWs-7Yuo4T2fgGfTC4vwho3Gp4ohIKJYS7O0,18452
 ml_tools/ML_datasetmaster/_vision_datasetmaster.py,sha256=kvSqXYeNBN1JSRfSEEXYeIcsqy9HsJAl_EwFWClqlsw,67025
 ml_tools/ML_evaluation/__init__.py,sha256=e3c8JNP0tt4Kxc7QSQpGcOgrxf8JAucH4UkJvJxUL2E,1122
-ml_tools/ML_evaluation/_classification.py,sha256=8bKQejKrgMipnxU1T12ted7p60xvJS0d0MvHtdNBCBM,30971
+ml_tools/ML_evaluation/_classification.py,sha256=Te5ckLfBCUyb3QO9vZ_mlJF5wS5LoajXC54k1Fkct-U,33938
 ml_tools/ML_evaluation/_feature_importance.py,sha256=mTwi3LKom_axu6UFKunELj30APDdhG9GQC2w7I9mYhI,17137
 ml_tools/ML_evaluation/_loss.py,sha256=1a4O25i3Ya_3naNZNL7ELLUL46BY86g1scA7d7q2UFM,3625
 ml_tools/ML_evaluation/_regression.py,sha256=hnT2B2_6AnQ7aA7uk-X2lZL9G5JFGCduDXyZbr1gFCA,11037
@@ -103,10 +103,10 @@ ml_tools/_core/__init__.py,sha256=m-VP0RW0tOTm9N5NI3kFNcpM7WtVgs0RK9pK3ZJRZQQ,14
 ml_tools/_core/_logger.py,sha256=xzhn_FouMDRVNwXGBGlPC9Ruq6i5uCrmNaS5jesguMU,4972
 ml_tools/_core/_schema_load_ops.py,sha256=KLs9vBzANz5ESe2wlP-C41N4VlgGil-ywcfvWKSOGss,1551
 ml_tools/_core/_script_info.py,sha256=LtFGt10gEvCnhIRMKJPi2yXkiGLcdr7lE-oIP2XGHzQ,234
-ml_tools/data_exploration/__init__.py,sha256=nYKg1bPBgXibC5nhmNKPw3VaKFeVtlNGL_YpHixW-Pg,1795
-ml_tools/data_exploration/_analysis.py,sha256=H6LryV56FFCHWjvQdkhZbtprZy6aP8EqU_hC2Cf9CLE,7832
+ml_tools/data_exploration/__init__.py,sha256=efUBsruHL56B429tUadl3PdG73zAF639Y430uMQRfko,1917
+ml_tools/data_exploration/_analysis.py,sha256=PJNrEBz5ZZXHoUlQ6fh9Y86nzPQrLpVPv2Ye4NfOxgs,14181
 ml_tools/data_exploration/_cleaning.py,sha256=pAZOXgGK35j7O8q6cnyTwYK1GLNnD04A8p2fSyMB1mg,20906
-ml_tools/data_exploration/_features.py,sha256=wW-M8n2aLIy05DR2z4fI8wjpPjn3mOAnm9aSGYbMKwI,23363
+ml_tools/data_exploration/_features.py,sha256=Z1noJfDxBzFRfusFp6NlpLF2NItuZuzFHq4ssWFqny4,26273
 ml_tools/data_exploration/_plotting.py,sha256=zH1dPcIoAlOuww23xIoBCsQOAshPPv9OyGposOA2RvI,19883
 ml_tools/data_exploration/_schema_ops.py,sha256=Fd6fBGGv4OpxmJ1HG9pith6QL90z0tzssCvzkQxlEEQ,11083
 ml_tools/ensemble_evaluation/__init__.py,sha256=t4Gr8EGEk8RLatyc92-S0BzbQvdvodzoF-qDAH2qjVg,546
@@ -118,7 +118,7 @@ ml_tools/ensemble_learning/_ensemble_learning.py,sha256=MHDZBR20_nStlSSeThFI3bSu
 ml_tools/excel_handler/__init__.py,sha256=AaWM3n_dqBhJLTs3OEA57ex5YykKXNOwVCyHlVsdnqI,530
 ml_tools/excel_handler/_excel_handler.py,sha256=TODudmeQgDSdxUKzLfAzizs--VL-g8WxDOfQ4sgxxLs,13965
 ml_tools/keys/__init__.py,sha256=-0c2pmrhyfROc-oQpEjJGLBMhSagA3CyFijQaaqZRqU,399
-ml_tools/keys/_keys.py,sha256=lL9NlijxOEAhfDPPqK_wL3QhjalrYK_fWM-KNniSIOA,9308
+ml_tools/keys/_keys.py,sha256=jBhw99SRTlBkb9EFMDLZA86_kaHT4YLxkljDYRCTarE,9389
 ml_tools/math_utilities/__init__.py,sha256=K7Obkkc4rPKj4EbRZf1BsXHfiCg7FXYv_aN9Yc2Z_Vg,400
 ml_tools/math_utilities/_math_utilities.py,sha256=BYHIVcM9tuKIhVrkgLLiM5QalJ39zx7dXYy_M9aGgiM,9012
 ml_tools/optimization_tools/__init__.py,sha256=KD8JXpfGuPndO4AHnjJGu6uV1GRwhOfboD0KZV45kzw,658
@@ -129,6 +129,10 @@ ml_tools/path_manager/_dragonmanager.py,sha256=q9wHTKPmdzywEz6N14ipUoeR3MmW0bzB4
 ml_tools/path_manager/_path_tools.py,sha256=LcZE31QlkzZWUR8g1MW_N_mPY2DpKBJLA45VJz7ZYsw,11905
 ml_tools/plot_fonts/__init__.py,sha256=KIxXRCjQ3SliEoLhEcqs7zDVZbVTn38bmSdL-yR1Q2w,187
 ml_tools/plot_fonts/_plot_fonts.py,sha256=mfjXNT9P59ymHoTI85Q8CcvfxfK5BIFBWtTZH-hNIC4,2209
+ml_tools/resampling/__init__.py,sha256=WB1YlNQgOIdSSQn-9eCIaiB0AHLSCkziFufqa-1QBG0,278
+ml_tools/resampling/_base_resampler.py,sha256=8IqkEJ7uiAiC9bqbKfsC-5vIvrN3EwH7lLVDlRKQzM8,1617
+ml_tools/resampling/_multi_resampling.py,sha256=m_iVvXPAu3p_EoBt2VZpgjhPLY1LmKa8fGtQo5E0pWk,7199
+ml_tools/resampling/_single_resampling.py,sha256=zKL4Br7Lm4Jq90X-ewQ6AKTsP923bq9RIMnTxIxtXBc,3896
 ml_tools/schema/__init__.py,sha256=K6uiZ9f0GCQ7etw1yl2-dQVLhU7RkL3KHesO3HNX6v4,334
 ml_tools/schema/_feature_schema.py,sha256=MuPf6Nf7tDhUTGyX7tcFHZh-lLSNsJkLmlf9IxdF4O4,9660
 ml_tools/schema/_gui_schema.py,sha256=IVwN4THAdFrvh2TpV4SFd_zlzMX3eioF-w-qcSVTndE,7245
@@ -138,7 +142,7 @@ ml_tools/utilities/__init__.py,sha256=h4lE3SQstg-opcQj6QSKhu-HkqSbmHExsWoM9vC5D9
 ml_tools/utilities/_translate.py,sha256=U8hRPa3PmTpIf9n9yR3gBGmp_hkcsjQLwjAHSHc0WHs,10325
 ml_tools/utilities/_utility_save_load.py,sha256=EFvFaTaHahDQWdJWZr-j7cHqRbG_Xrpc96228JhV-bs,16773
 ml_tools/utilities/_utility_tools.py,sha256=bN0J9d1S0W5wNzNntBWqDsJcEAK7-1OgQg3X2fwXns0,6918
-dragon_ml_toolbox-20.7.1.dist-info/METADATA,sha256=IB7aIajHgmlg0UvpBOjDfCiQWfNmM0G3NKSpiEvDlAs,7866
-dragon_ml_toolbox-20.7.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-20.7.1.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-20.7.1.dist-info/RECORD,,
+dragon_ml_toolbox-20.8.0.dist-info/METADATA,sha256=EVzUhpCzHarcTicuqc_t4prSuJdXGuCppSX7wnIv1JY,7888
+dragon_ml_toolbox-20.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-20.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-20.8.0.dist-info/RECORD,,
ml_tools/ML_configuration/_metrics.py CHANGED
@@ -98,10 +98,11 @@ class _BaseMultiLabelFormat:
                  cmap: str = "BuGn",
                  ROC_PR_line: str='darkorange',
                  calibration_bins: Union[int, Literal['auto']]='auto',
-                 font_size: int = 25,
-                 xtick_size: int=20,
-                 ytick_size: int=20,
-                 legend_size: int=23) -> None:
+                 font_size: int = 26,
+                 xtick_size: int=22,
+                 ytick_size: int=22,
+                 legend_size: int=26,
+                 cm_font_size: int=26) -> None:
         """
         Initializes the formatting configuration for multi-label classification metrics.
 
@@ -127,6 +128,8 @@ class _BaseMultiLabelFormat:
 
             legend_size (int): Font size for plot legends.
 
+            cm_font_size (int): Font size for the confusion matrix.
+
         <br>
 
         ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
@@ -142,6 +145,7 @@ class _BaseMultiLabelFormat:
         self.xtick_size = xtick_size
         self.ytick_size = ytick_size
         self.legend_size = legend_size
+        self.cm_font_size = cm_font_size
 
     def __repr__(self) -> str:
         parts = [
@@ -151,7 +155,8 @@ class _BaseMultiLabelFormat:
            f"font_size={self.font_size}",
            f"xtick_size={self.xtick_size}",
            f"ytick_size={self.ytick_size}",
-           f"legend_size={self.legend_size}"
+           f"legend_size={self.legend_size}",
+           f"cm_font_size={self.cm_font_size}"
         ]
         return f"{self.__class__.__name__}({', '.join(parts)})"
 
@@ -520,10 +525,11 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
                  cmap: str = "BuGn",
                  ROC_PR_line: str='darkorange',
                  calibration_bins: Union[int, Literal['auto']]='auto',
-                 font_size: int = 25,
-                 xtick_size: int=20,
-                 ytick_size: int=20,
-                 legend_size: int=23
+                 font_size: int = 26,
+                 xtick_size: int=22,
+                 ytick_size: int=22,
+                 legend_size: int=26,
+                 cm_font_size: int=26
                  ) -> None:
         super().__init__(cmap=cmap,
                          ROC_PR_line=ROC_PR_line,
@@ -531,7 +537,8 @@ class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
                          font_size=font_size,
                          xtick_size=xtick_size,
                          ytick_size=ytick_size,
-                         legend_size=legend_size)
+                         legend_size=legend_size,
+                         cm_font_size=cm_font_size)
 
 
 # Segmentation
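These hunks raise the default plot font sizes and thread a new `cm_font_size` option through the multi-label format config. A minimal usage sketch follows; the import path is an assumption inferred from the package layout in the RECORD, not confirmed by this diff:

```python
# Sketch only: the import path below is assumed from the RECORD above.
from ml_tools.ML_configuration import FormatMultiLabelBinaryClassificationMetrics

fmt = FormatMultiLabelBinaryClassificationMetrics(
    cmap="BuGn",
    cm_font_size=30,  # new in 20.8.0: font size for confusion matrix / report heatmap text
)
print(fmt)  # __repr__ now reports cm_font_size as well
```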
ml_tools/ML_evaluation/_classification.py CHANGED
@@ -481,6 +481,10 @@ def multi_label_classification_metrics(
     ytick_size = format_config.ytick_size
     legend_size = format_config.legend_size
     base_font_size = format_config.font_size
+
+    # config font size for heatmap
+    cm_font_size = format_config.cm_font_size
+    cm_tick_size = cm_font_size - 4
 
     # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
@@ -488,7 +492,7 @@
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics:\n" # No threshold to report here
+        f"Overall Multi-Label Metrics:\n"
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -498,14 +502,82 @@
     # print(overall_report)
     overall_report_path = save_dir_path / "classification_report.txt"
     overall_report_path.write_text(overall_report)
+
+    # --- Save Classification Report Heatmap (Multi-label) ---
+    try:
+        # Generate full report as dict
+        full_report_dict = classification_report(y_true, y_pred, target_names=target_names, output_dict=True)
+        report_df = pd.DataFrame(full_report_dict)
+
+        # Cleanup
+        # Remove 'accuracy' column if it exists
+        report_df = report_df.drop(columns=['accuracy'], errors='ignore')
+
+        # Remove 'support' row explicitly
+        if 'support' in report_df.index:
+            report_df = report_df.drop(index='support')
+
+        # Transpose: Rows = Classes/Averages, Cols = Metrics
+        plot_df = report_df.T
+
+        # Dynamic Height
+        fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
+        fig_width = 8.0
+
+        fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
+
+        # Plot
+        sns.heatmap(plot_df,
+                    annot=True,
+                    cmap=format_config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0,
+                    cbar_kws={'shrink': 0.9})
+
+        ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+
+        # manually increase the font size of the elements
+        for text in ax_heat.texts:
+            text.set_fontsize(cm_tick_size)
+
+        cbar = ax_heat.collections[0].colorbar
+        cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore
+
+        ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
+        ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0)
+
+        plt.tight_layout()
+        heatmap_path = save_dir_path / "classification_report_heatmap.svg"
+        plt.savefig(heatmap_path)
+        _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
+        plt.close(fig_heat)
+
+    except Exception as e:
+        _LOGGER.error(f"Could not generate multi-label classification report heatmap: {e}")
 
     # --- Per-Label Metrics and Plots ---
     for i, name in enumerate(target_names):
-        print(f"  -> Evaluating label: '{name}'")
+        # strip whitespace from name
+        name = name.strip()
+
+        # print(f"  -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
         pred_i = y_pred[:, i] # Use passed-in y_pred
         prob_i = y_prob[:, i] # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)
+
+        # if name is too long, just take the first letter of each word. Each word might be separated by space or underscore
+        if len(name) >= _EvaluationConfig.NAME_LIMIT:
+            parts = [w for w in name.replace("_", " ").split() if w]
+            abbr = "".join(p[0].upper() for p in parts)
+            # keep only alpha numeric chars
+            abbr = "".join(ch for ch in abbr if ch.isalnum())
+            if not abbr:
+                # fallback to a sanitized, truncated version of the original name
+                abbr = sanitize_filename(name)[: _EvaluationConfig.NAME_LIMIT]
+            _LOGGER.warning(f"Using abbreviated name '{abbr}' for '{name}' plots.")
+            name = abbr
 
         # --- Save Classification Report for the label (uses y_pred) ---
         report_text = classification_report(true_i, pred_i)
@@ -537,7 +609,7 @@ def multi_label_classification_metrics(
         ax_cm.tick_params(axis='y', labelsize=ytick_size)
 
         # Set titles and labels with padding
-        ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_cm.set_title(f"Confusion Matrix - {name}", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -594,7 +666,7 @@ def multi_label_classification_metrics(
         ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
 
-        ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_roc.set_title(f'ROC Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -616,7 +688,7 @@ def multi_label_classification_metrics(
         ap_score = average_precision_score(true_i, prob_i)
         fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
         ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
-        ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_pr.set_title(f'PR Curve - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
@@ -659,10 +731,10 @@ def multi_label_classification_metrics(
                     prob_true,
                     marker='o',
                     linewidth=2,
-                    label=f"Calibration for '{name}'",
+                    label=f"Model Calibration",
                     color=format_config.ROC_PR_line)
 
-        ax_cal.set_title(f'Reliability Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_cal.set_title(f'Calibration - {name}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
         ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
         ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
 
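Two behavioral changes above are easy to miss: per-label names are now stripped of whitespace, and names at or past `_EvaluationConfig.NAME_LIMIT` characters are abbreviated to initials before being used in plot titles and filenames. A standalone sketch of that abbreviation logic (limit inlined; the real code falls back to `sanitize_filename` rather than plain slicing):

```python
NAME_LIMIT = 20  # mirrors _EvaluationConfig.NAME_LIMIT added in this release

def abbreviate_label(name: str) -> str:
    """Collapse long labels to alphanumeric initials, e.g. 'very_long_label_name' -> 'VLLN'."""
    name = name.strip()
    if len(name) < NAME_LIMIT:
        return name
    parts = [w for w in name.replace("_", " ").split() if w]
    abbr = "".join(p[0].upper() for p in parts)
    abbr = "".join(ch for ch in abbr if ch.isalnum())
    return abbr or name[:NAME_LIMIT]  # the package uses sanitize_filename() here instead

print(abbreviate_label("histological grade of tumor"))  # -> 'HGOT'
```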
ml_tools/data_exploration/__init__.py CHANGED
@@ -2,6 +2,7 @@ from ._analysis import (
     summarize_dataframe,
     show_null_columns,
     match_and_filter_columns_by_regex,
+    check_class_balance,
 )
 
 from ._cleaning import (
@@ -28,6 +29,7 @@ from ._features import (
     split_continuous_binary,
     split_continuous_categorical_targets,
     encode_categorical_features,
+    encode_classification_target,
     reconstruct_one_hot,
     reconstruct_binary,
     reconstruct_multibinary,
@@ -44,7 +46,6 @@ from .._core import _imprimir_disponibles
 
 __all__ = [
     "summarize_dataframe",
-    "show_null_columns",
     "drop_constant_columns",
    "drop_rows_with_missing_data",
    "drop_columns_with_missing_data",
@@ -61,10 +62,13 @@ __all__ = [
    "plot_categorical_vs_target",
    "plot_correlation_heatmap",
    "encode_categorical_features",
+    "encode_classification_target",
    "finalize_feature_schema",
    "apply_feature_schema",
    "reconstruct_from_schema",
    "match_and_filter_columns_by_regex",
+    "show_null_columns",
+    "check_class_balance",
    "standardize_percentages",
    "reconstruct_one_hot",
    "reconstruct_binary",
ml_tools/data_exploration/_analysis.py CHANGED
@@ -16,6 +16,7 @@ __all__ = [
     "summarize_dataframe",
     "show_null_columns",
     "match_and_filter_columns_by_regex",
+    "check_class_balance",
 ]
 
 
@@ -212,3 +213,151 @@ def match_and_filter_columns_by_regex(
 
     return filtered_df, matched_columns
 
+
+def check_class_balance(
+    df: pd.DataFrame,
+    target: Union[str, list[str]],
+    plot_to_dir: Optional[Union[str, Path]] = None,
+    plot_filename: str = "Class_Balance"
+) -> pd.DataFrame:
+    """
+    Analyzes the class balance for classification targets.
+
+    Handles two cases:
+    1. Single Column (Binary/Multi-class): Calculates frequency of each unique value.
+    2. List of Columns (Multi-label Binary): Calculates the frequency of positive values (1) per column.
+
+    Args:
+        df (pd.DataFrame): The input DataFrame.
+        target (str | list[str]): The target column name (for single/multi-class classification)
+            or list of column names (for multi-label-binary classification).
+        plot_to_dir (str | Path | None): Directory to save the balance plot.
+        plot_filename (str): Filename for the plot (without extension).
+
+    Returns:
+        pd.DataFrame: Summary table of counts and percentages.
+    """
+    # Early fail for empty DataFrame and handle list of targets with only one item
+    if df.empty:
+        _LOGGER.error("Input DataFrame is empty.")
+        raise ValueError()
+
+    if isinstance(target, list):
+        if len(target) == 0:
+            _LOGGER.error("Target list is empty.")
+            raise ValueError()
+        elif len(target) == 1:
+            target = target[0] # Simplify to single column case
+
+    # Case 1: Single Target (Binary or Multi-class)
+    if isinstance(target, str):
+        if target not in df.columns:
+            _LOGGER.error(f"Target column '{target}' not found in DataFrame.")
+            raise ValueError()
+
+        # Calculate stats
+        counts = df[target].value_counts(dropna=False).sort_index()
+        percents = df[target].value_counts(normalize=True, dropna=False).sort_index() * 100
+
+        summary = pd.DataFrame({
+            'Count': counts,
+            'Percentage': percents.round(2)
+        })
+        summary.index.name = "Class"
+
+        # Plotting
+        if plot_to_dir:
+            try:
+                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")
+
+                plt.figure(figsize=(10, 6))
+                # Convert index to str to handle numeric classes cleanly on x-axis
+                x_labels = summary.index.astype(str)
+                bars = plt.bar(x_labels, summary['Count'], color='lightgreen', edgecolor='black', alpha=0.7)
+
+                plt.title(f"Class Balance: {target}")
+                plt.xlabel(target)
+                plt.ylabel("Count")
+                plt.grid(axis='y', linestyle='--', alpha=0.5)
+
+                # Add percentage labels on top of bars
+                for bar, pct in zip(bars, summary['Percentage']):
+                    height = bar.get_height()
+                    plt.text(bar.get_x() + bar.get_width()/2, height,
+                             f'{pct:.1f}%', ha='center', va='bottom', fontsize=10)
+
+                plt.tight_layout()
+                full_filename = sanitize_filename(plot_filename) + ".svg"
+                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
+                plt.close()
+                _LOGGER.info(f"Saved class balance plot: '{full_filename}'")
+            except Exception as e:
+                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
+                plt.close()
+
+        return summary
+
+    # Case 2: Multi-label (List of binary columns)
+    elif isinstance(target, list):
+        missing_cols = [t for t in target if t not in df.columns]
+        if missing_cols:
+            _LOGGER.error(f"Target columns not found: {missing_cols}")
+            raise ValueError()
+
+        stats = []
+        for col in target:
+            # Assume 0/1 or False/True. Sum gives the count of positives.
+            # We enforce numeric to be safe
+            try:
+                numeric_series = pd.to_numeric(df[col], errors='coerce').fillna(0)
+                pos_count = numeric_series.sum()
+                total_count = len(df)
+                pct = (pos_count / total_count) * 100
+            except Exception:
+                _LOGGER.warning(f"Column '{col}' could not be processed as numeric. Assuming 0 positives.")
+                pos_count = 0
+                pct = 0.0
+
+            stats.append({
+                'Label': col,
+                'Positive_Count': int(pos_count),
+                'Positive_Percentage': round(pct, 2)
+            })
+
+        summary = pd.DataFrame(stats).set_index("Label").sort_values("Positive_Percentage", ascending=True)
+
+        # Plotting
+        if plot_to_dir:
+            try:
+                save_path = make_fullpath(plot_to_dir, make=True, enforce="directory")
+
+                # Dynamic height for many labels
+                height = max(6, len(target) * 0.4)
+                plt.figure(figsize=(10, height))
+
+                bars = plt.barh(summary.index, summary['Positive_Percentage'], color='lightgreen', edgecolor='black', alpha=0.7)
+
+                plt.title(f"Multi-label Binary Class Balance")
+                plt.xlabel("Positive Class Percentage (%)")
+                plt.xlim(0, 100)
+                plt.grid(axis='x', linestyle='--', alpha=0.5)
+
+                # Add count labels at the end of bars
+                for bar, count in zip(bars, summary['Positive_Count']):
+                    width = bar.get_width()
+                    plt.text(width + 1, bar.get_y() + bar.get_height()/2, f'{width:.1f}%', ha='left', va='center', fontsize=9)
+
+                plt.tight_layout()
+                full_filename = sanitize_filename(plot_filename) + ".svg"
+                plt.savefig(save_path / full_filename, format='svg', bbox_inches="tight")
+                plt.close()
+                _LOGGER.info(f"Saved multi-label balance plot: '{full_filename}'")
+            except Exception as e:
+                _LOGGER.error(f"Failed to plot class balance. Error: {e}")
+                plt.close()
+
+        return summary.sort_values("Positive_Percentage", ascending=False)
+
+    else:
+        _LOGGER.error("Target must be a string or a list of strings.")
+        raise TypeError()
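A short usage sketch for the new `check_class_balance` (toy data; the import path follows the `data_exploration/__init__.py` change above):

```python
import pandas as pd
from ml_tools.data_exploration import check_class_balance

df = pd.DataFrame({
    "label":    ["cat", "dog", "dog", "dog", "bird"],
    "is_toxic": [0, 1, 0, 0, 1],
    "is_rare":  [1, 0, 0, 0, 1],
})

# Case 1: single target -> per-class counts and percentages (bar chart saved to ./reports).
summary = check_class_balance(df, target="label", plot_to_dir="reports")

# Case 2: list of binary targets -> positive rate per label (horizontal bar chart).
ml_summary = check_class_balance(df, target=["is_toxic", "is_rare"],
                                 plot_to_dir="reports", plot_filename="Multilabel_Balance")
```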
ml_tools/data_exploration/_features.py CHANGED
@@ -3,7 +3,10 @@ from pandas.api.types import is_numeric_dtype, is_object_dtype
 import numpy as np
 from typing import Any, Optional, Union
 import re
+import json
+from pathlib import Path
 
+from ..path_manager import make_fullpath
 from .._core import get_logger
 
 
@@ -15,6 +18,7 @@ __all__ = [
     "split_continuous_binary",
     "split_continuous_categorical_targets",
     "encode_categorical_features",
+    "encode_classification_target",
     "reconstruct_one_hot",
     "reconstruct_binary",
     "reconstruct_multibinary",
@@ -263,6 +267,78 @@ def encode_categorical_features(
     return df_encoded, mappings
 
 
+def encode_classification_target(
+    df: pd.DataFrame,
+    target_col: str,
+    save_dir: Union[str, Path],
+    verbose: int = 2
+) -> tuple[pd.DataFrame, dict[str, int]]:
+    """
+    Encodes a target classification column into integers (0, 1, 2...) and saves the mapping to a JSON file.
+
+    This ensures that the target variable is in the correct numeric format for training
+    and provides a persistent artifact (the JSON file) to map predictions back to labels later.
+
+    Args:
+        df (pd.DataFrame): Input DataFrame.
+        target_col (str): Name of the target column to encode.
+        save_dir (str | Path): Directory where the class map JSON will be saved.
+        verbose (int): Verbosity level for logging.
+
+    Returns:
+        Tuple (Dataframe, Dict):
+        - A new DataFrame with the target column encoded as integers.
+        - The dictionary mapping original labels (str) to integers (int).
+    """
+    if target_col not in df.columns:
+        _LOGGER.error(f"Target column '{target_col}' not found in DataFrame.")
+        raise ValueError()
+
+    # Validation: Check for missing values in target
+    if df[target_col].isnull().any():
+        n_missing = df[target_col].isnull().sum()
+        _LOGGER.error(f"Target column '{target_col}' contains {n_missing} missing values. Please handle them before encoding.")
+        raise ValueError()
+
+    # Ensure directory exists
+    save_path = make_fullpath(save_dir, make=True, enforce="directory")
+    file_path = save_path / "class_map.json"
+
+    # Get unique values and sort them to ensure deterministic encoding (0, 1, 2...)
+    # Convert to string to ensure the keys in JSON are strings
+    unique_labels = sorted(df[target_col].astype(str).unique())
+
+    # Create mapping: { Label -> Integer }
+    class_map = {label: idx for idx, label in enumerate(unique_labels)}
+
+    # Apply mapping
+    # cast column to string to match the keys in class_map
+    df_encoded = df.copy()
+    df_encoded[target_col] = df_encoded[target_col].astype(str).map(class_map)
+
+    # Save to JSON
+    try:
+        with open(file_path, 'w', encoding='utf-8') as f:
+            json.dump(class_map, f, indent=4)
+
+        if verbose >= 2:
+            _LOGGER.info(f"Class mapping saved to: '{file_path}'")
+
+        if verbose >= 3:
+            _LOGGER.info(f"Target '{target_col}' encoded with {len(class_map)} classes.")
+            # Print a preview
+            if len(class_map) <= 10:
+                print(f"    Mapping: {class_map}")
+            else:
+                print(f"    Mapping (first 5): {dict(list(class_map.items())[:5])} ...")
+
+    except Exception as e:
+        _LOGGER.error(f"Failed to save class map JSON. Error: {e}")
+        raise IOError()
+
+    return df_encoded, class_map
+
+
 def reconstruct_one_hot(
     df: pd.DataFrame,
     features_to_reconstruct: list[Union[str, tuple[str, Optional[str]]]],
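And a sketch for `encode_classification_target` (toy data; note that labels are stringified and sorted before being assigned integer codes, so the mapping is deterministic across runs):

```python
import pandas as pd
from ml_tools.data_exploration import encode_classification_target

df = pd.DataFrame({"species": ["setosa", "virginica", "setosa", "versicolor"]})

# Writes <save_dir>/class_map.json alongside returning the mapping.
df_encoded, class_map = encode_classification_target(df, target_col="species",
                                                     save_dir="artifacts")
print(class_map)                       # {'setosa': 0, 'versicolor': 1, 'virginica': 2}
print(df_encoded["species"].tolist())  # [0, 2, 0, 1]
```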
ml_tools/keys/_keys.py CHANGED
@@ -306,6 +306,7 @@ class _EvaluationConfig:
     LOSS_PLOT_LEGEND_SIZE = 24
     # CM settings
     CM_SIZE = (9, 8) # used for multi label binary classification confusion matrix
+    NAME_LIMIT = 20 # max number of characters for feature/label names in plots
 
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
ml_tools/resampling/__init__.py ADDED
@@ -0,0 +1,19 @@
+from ._single_resampling import (
+    DragonResampler,
+)
+
+from ._multi_resampling import (
+    DragonMultiResampler,
+)
+
+from .._core import _imprimir_disponibles
+
+
+__all__ = [
+    "DragonResampler",
+    "DragonMultiResampler",
+]
+
+
+def info():
+    _imprimir_disponibles(__all__)
ml_tools/resampling/_base_resampler.py ADDED
@@ -0,0 +1,49 @@
+import polars as pl
+import pandas as pd
+from typing import Union
+from abc import ABC, abstractmethod
+
+
+__all__ = ["_DragonBaseResampler"]
+
+
+class _DragonBaseResampler(ABC):
+    """
+    Base class for Dragon resamplers handling common I/O and state.
+    """
+    def __init__(self,
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        self.return_pandas = return_pandas
+        self.seed = seed
+
+    def _convert_to_polars(self, df: Union[pd.DataFrame, pl.DataFrame]) -> pl.DataFrame:
+        """Standardizes input to Polars DataFrame."""
+        if isinstance(df, pd.DataFrame):
+            return pl.from_pandas(df)
+        return df
+
+    def _convert_to_pandas(self, df: pl.DataFrame) -> pd.DataFrame:
+        """Converts Polars DataFrame back to Pandas."""
+        return df.to_pandas(use_pyarrow_extension_array=False)
+
+    def _process_return(self, df: pl.DataFrame, shuffle: bool = True) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Finalizes the DataFrame:
+        1. Global Shuffle (optional but recommended for ML).
+        2. Conversion to Pandas (if requested).
+        """
+        if shuffle:
+            # Random shuffle of the final dataset
+            df = df.sample(fraction=1.0, seed=self.seed, with_replacement=False)
+
+        if self.return_pandas:
+            return self._convert_to_pandas(df)
+        return df
+
+    @abstractmethod
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        """
+        Prints a statistical summary of the target distribution.
+        """
+        pass
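The base class leans on one Polars idiom worth calling out: `sample(fraction=1.0, with_replacement=False)` returns every row exactly once in a seed-deterministic random order, i.e. a global shuffle. A minimal demonstration:

```python
import polars as pl

df = pl.DataFrame({"x": [1, 2, 3, 4, 5]})

# Full-fraction sampling without replacement is a permutation of the rows.
shuffled = df.sample(fraction=1.0, seed=42, with_replacement=False)
assert sorted(shuffled["x"].to_list()) == [1, 2, 3, 4, 5]
```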
ml_tools/resampling/_multi_resampling.py ADDED
@@ -0,0 +1,184 @@
+import polars as pl
+import pandas as pd
+import numpy as np
+from typing import Union, Optional
+
+from .._core import get_logger
+
+from ._base_resampler import _DragonBaseResampler
+
+
+_LOGGER = get_logger("DragonMultiResampler")
+
+
+__all__ = [
+    "DragonMultiResampler",
+]
+
+
+class DragonMultiResampler(_DragonBaseResampler):
+    """
+    A robust resampler for multi-label binary classification tasks using Polars.
+
+    It provides methods to downsample "all-negative" rows and balance the dataset
+    based on unique label combinations (Powerset).
+    """
+    def __init__(self,
+                 target_columns: list[str],
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        """
+        Args:
+            target_columns (List[str]): The list of binary target column names.
+            return_pandas (bool): Whether to return results as pandas DataFrame.
+            seed (int): Random seed for reproducibility.
+        """
+        super().__init__(return_pandas=return_pandas, seed=seed)
+        self.targets = target_columns
+
+    def downsample_all_negatives(self,
+                                 df: Union[pd.DataFrame, pl.DataFrame],
+                                 negative_ratio: float = 1.0,
+                                 verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Downsamples rows where ALL target columns are 0 ("background" class).
+
+        Args:
+            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
+            negative_ratio (float): Ratio of negatives to positives to retain.
+            verbose (int): Verbosity level for logging.
+
+        Returns:
+            Dataframe: Resampled DataFrame.
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Identify "All Negative" vs "Has Signal"
+        fold_expr = pl.sum_horizontal(pl.col(self.targets)).cast(pl.UInt32)
+
+        df_pos = df_pl.filter(fold_expr > 0)
+        df_neg = df_pl.filter(fold_expr == 0)
+
+        n_pos = df_pos.height
+        n_neg_original = df_neg.height
+
+        if n_pos == 0:
+            if verbose >= 1:
+                _LOGGER.warning("No positive cases found in any label. Returning original DataFrame.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 2. Calculate target count for negatives
+        target_n_neg = int(n_pos * negative_ratio)
+
+        # 3. Sample if necessary
+        if n_neg_original > target_n_neg:
+            if verbose >= 2:
+                _LOGGER.info(f"📉 Downsampling 'All-Negative' rows from {n_neg_original} to {target_n_neg}")
+
+            # Here we use standard sampling because we are not grouping
+            df_neg_sampled = df_neg.sample(n=target_n_neg, seed=self.seed, with_replacement=False)
+            df_resampled = pl.concat([df_pos, df_neg_sampled])
+
+            return self._process_return(df_resampled)
+        else:
+            if verbose >= 1:
+                _LOGGER.warning(f"Negative count ({n_neg_original}) is already below target ({target_n_neg}). No downsampling applied.")
+            return self._process_return(df_pl, shuffle=False)
+
+    def balance_powerset(self,
+                         df: Union[pd.DataFrame, pl.DataFrame],
+                         max_samples_per_combination: Optional[int] = None,
+                         quantile_limit: float = 0.90,
+                         verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Groups data by unique label combinations (Powerset) and downsamples
+        majority combinations.
+
+        Args:
+            df (pd.DataFrame | pl.DataFrame): Input DataFrame.
+            max_samples_per_combination (int | None): Fixed cap per combination.
+                If None, uses quantile_limit to determine cap.
+            quantile_limit (float): Quantile to determine cap if max_samples_per_combination is None.
+            verbose (int): Verbosity level for logging.
+
+        Returns:
+            Dataframe: Resampled DataFrame.
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Create a hash/structural representation of the targets for grouping
+        df_lazy = df_pl.lazy().with_columns(
+            pl.concat_list(pl.col(self.targets)).alias("_powerset_key")
+        )
+
+        # 2. Calculate frequencies
+        # We need to collect partially to calculate the quantile cap
+        combo_counts = df_lazy.group_by("_powerset_key").len().collect()
+
+        # Determine the Cap
+        if max_samples_per_combination is None:
+            # Handle potential None from quantile (satisfies linter)
+            q_val = combo_counts["len"].quantile(quantile_limit)
+
+            if q_val is None:
+                if verbose >= 1:
+                    _LOGGER.warning("Data empty or insufficient to calculate quantile. Returning original.")
+                return self._process_return(df_pl, shuffle=False)
+
+            cap_size = int(q_val)
+
+            if verbose >= 3:
+                _LOGGER.info(f"📊 Auto-calculated Powerset Cap: {cap_size} samples (based on {quantile_limit} quantile).")
+        else:
+            cap_size = max_samples_per_combination
+
+        # 3. Apply Stratified Sampling / Capping (Randomized)
+        df_balanced = (
+            df_lazy
+            .filter(
+                pl.int_range(0, pl.len())
+                .shuffle(seed=self.seed)
+                .over("_powerset_key")
+                < cap_size
+            )
+            .drop("_powerset_key")
+            .collect()
+        )
+
+        if verbose >= 2:
+            original_count = df_pl.height
+            new_count = df_balanced.height
+            _LOGGER.info(f"⚖️ Powerset Balancing: Reduced from {original_count} to {new_count} rows.")
+
+        return self._process_return(df_balanced)
+
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        df_pl = self._convert_to_polars(df)
+        total_rows = df_pl.height
+
+        message_1 = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Multi-Targets: {len(self.targets)} columns"
+
+        # A. Individual Label Counts
+        sums = df_pl.select([
+            pl.sum(col).alias(col) for col in self.targets
+        ]).transpose(include_header=True, header_name="Label", column_names=["Count"])
+
+        sums = sums.with_columns(
+            (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+        ).sort("Count", descending=True)
+
+        message_1 += "\n🔹 Individual Label Frequencies:"
+
+        # B. Powerset (Combination) Counts
+        message_2 = f"🔹 Top {top_n} Label Combinations (Powerset):"
+
+        combo_stats = (
+            df_pl.group_by(self.targets)
+            .len(name="Count")
+            .sort("Count", descending=True)
+            .with_columns(
+                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+            )
+        )
+
+        _LOGGER.info(f"{message_1}\n{sums.head(top_n)}\n{message_2}\n{combo_stats.head(top_n)}")
ml_tools/resampling/_single_resampling.py ADDED
@@ -0,0 +1,113 @@
+import polars as pl
+import pandas as pd
+import numpy as np
+from typing import Union
+
+from .._core import get_logger
+
+from ._base_resampler import _DragonBaseResampler
+
+
+_LOGGER = get_logger("DragonResampler")
+
+
+__all__ = [
+    "DragonResampler",
+]
+
+
+class DragonResampler(_DragonBaseResampler):
+    """
+    A resampler for Single-Target Classification tasks (Binary or Multiclass).
+
+    It balances classes by downsampling majority classes relative to the size of the minority class.
+    """
+    def __init__(self,
+                 target_column: str,
+                 return_pandas: bool = False,
+                 seed: int = 42):
+        """
+        Args:
+            target_column (str): The name of the single target column.
+            return_pandas (bool): Whether to return results as pandas DataFrame.
+            seed (int): Random seed for reproducibility.
+        """
+        super().__init__(return_pandas=return_pandas, seed=seed)
+        self.target = target_column
+
+    def balance_classes(self,
+                        df: Union[pd.DataFrame, pl.DataFrame],
+                        majority_ratio: float = 1.0,
+                        verbose: int = 2) -> Union[pd.DataFrame, pl.DataFrame]:
+        """
+        Downsamples all classes to match the minority class size (scaled by a ratio).
+        """
+        df_pl = self._convert_to_polars(df)
+
+        # 1. Calculate Class Counts
+        counts = df_pl.group_by(self.target).len().sort("len")
+
+        if counts.height == 0:
+            _LOGGER.error("DataFrame is empty or target column missing.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 2. Identify Statistics
+        min_val = counts["len"].min()
+        max_val = counts["len"].max()
+
+        if min_val is None or max_val is None:
+            _LOGGER.error("Failed to calculate class statistics (unexpected None).")
+            raise ValueError()
+
+        minority_count: int = min_val # type: ignore
+        majority_count: int = max_val # type: ignore
+
+        # Calculate the cap
+        cap_size = int(minority_count * majority_ratio)
+
+        if verbose >= 3:
+            _LOGGER.info(f"📊 Class Distribution:\n{counts}")
+            _LOGGER.info(f"🎯 Strategy: Cap majorities at {cap_size}")
+
+        # Optimization: If data is already balanced enough
+        if majority_count <= cap_size:
+            if verbose >= 2:
+                _LOGGER.info("Data is already within the requested balance ratio.")
+            return self._process_return(df_pl, shuffle=False)
+
+        # 3. Apply Downsampling (Randomized)
+        # We generate a random range index per group and filter by it.
+        # This ensures we pick a random subset, not the first N rows.
+        df_balanced = (
+            df_pl.lazy()
+            .filter(
+                pl.int_range(0, pl.len())
+                .shuffle(seed=self.seed)
+                .over(self.target)
+                < cap_size
+            )
+            .collect()
+        )
+
+        if verbose >= 2:
+            reduced_count = df_balanced.height
+            _LOGGER.info(f"⚖️ Balancing Complete: {df_pl.height} -> {reduced_count} rows.")
+
+        return self._process_return(df_balanced)
+
+    def describe_balance(self, df: Union[pd.DataFrame, pl.DataFrame], top_n: int = 10) -> None:
+        df_pl = self._convert_to_polars(df)
+        total_rows = df_pl.height
+
+        message = f"\n📊 --- Target Balance Report ({total_rows} samples) ---\n🎯 Single Target: '{self.target}'"
+
+        stats = (
+            df_pl.group_by(self.target)
+            .len(name="Count")
+            .sort("Count", descending=True)
+            .with_columns(
+                (pl.col("Count") / total_rows * 100).round(2).alias("Percentage(%)")
+            )
+        )
+
+        _LOGGER.info(f"{message}\n{stats.head(top_n)}")