dragon-ml-toolbox 19.6.0__py3-none-any.whl → 19.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 19.6.0
+ Version: 19.7.0
  Summary: Complete pipelines and helper tools for data science and machine learning projects.
  Author-email: Karl Luigi Loza Vidaurre <luigiloza@gmail.com>
  License-Expression: MIT
@@ -1,5 +1,5 @@
- dragon_ml_toolbox-19.6.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
- dragon_ml_toolbox-19.6.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=XBLtvGjvBf-q93a5iylHj94Lm78UzInC-3Cii01jc6I,3127
+ dragon_ml_toolbox-19.7.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+ dragon_ml_toolbox-19.7.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=XBLtvGjvBf-q93a5iylHj94Lm78UzInC-3Cii01jc6I,3127
  ml_tools/ETL_cleaning.py,sha256=cKXyRFaaFs_beAGDnQM54xnML671kq-yJEGjHafW-20,351
  ml_tools/ETL_engineering.py,sha256=cwh1FhtNdUHllUDvho-x3SIVj4KwG_rFQR6VYzWUg0U,898
  ml_tools/GUI_tools.py,sha256=O89rG8WQv6GY1DiphQjIsPzXFCQID6te7q_Sgt1iTkQ,294
@@ -58,12 +58,12 @@ ml_tools/_core/_MICE_imputation.py,sha256=_juIymUnNDRWjSLepL8Ee_PncoShbxjR7YtqTt
  ml_tools/_core/_ML_callbacks.py,sha256=qtCrVFHTq-nk4NIsAdwIkfkKwFXX6I-6PoCgqZELp70,16734
  ml_tools/_core/_ML_chaining_inference.py,sha256=vXUPZzuQ2yKU71kkvUsE0xPo0hN-Yu6gfnL0JbXoRjI,7783
  ml_tools/_core/_ML_chaining_utilities.py,sha256=nsYowgRbkIYuzRiHlqsM3tnC3c-8O73CY8DHUF14XL0,19248
- ml_tools/_core/_ML_configuration.py,sha256=i52e_NH37fTDU3JmePPlp5Vq9UDukY9ro1q4SVai-Zo,45213
+ ml_tools/_core/_ML_configuration.py,sha256=t_6p_slOPhmy04wJcQj6D_bJyZRaXlsIeWIiULaJnXc,48716
  ml_tools/_core/_ML_configuration_pytab.py,sha256=C3e4iScqdRePVDoqnic6xXMOW7DNYqpgTCeaFDyMdL4,3286
  ml_tools/_core/_ML_datasetmaster.py,sha256=yU1BMtzz6XumMWCetVACrRLk7WJQwmYhaQ-VAWu9Ots,32043
- ml_tools/_core/_ML_evaluation.py,sha256=KfUSgpww0IlmJaM_RQJvsFY8U-N2d2AbLYJzsfyWRTU,30760
+ ml_tools/_core/_ML_evaluation.py,sha256=bu8qlYzhWSC1B7wNfCC5TSF-oed-uP8EF7TV45VTiBM,37325
  ml_tools/_core/_ML_evaluation_captum.py,sha256=a69jnghIzE9qppuw2vzTBMdTErnZkDkTA3MPUUYjsS4,19212
- ml_tools/_core/_ML_evaluation_multi.py,sha256=CFFmGUt2733VpHSzbRpHXSUnz8Vfe0rWQ7NLaLSMmv0,20710
+ ml_tools/_core/_ML_evaluation_multi.py,sha256=n_AJbKF58DMUrYqJutwPFV5z6sNssDPA1Gl05IfPG5s,23647
  ml_tools/_core/_ML_finalize_handler.py,sha256=0eZ_0N2L5aUUIJUgvhAQ-rbd8XbE9UmNqTKSJq09uTI,6987
  ml_tools/_core/_ML_inference.py,sha256=5swm2lnsrDLalBnCm7gZPlDucX4yNCq5vn7ck3SW_4Q,29791
  ml_tools/_core/_ML_models.py,sha256=8FUx4-TVghlBF9srh1_5UxovrWPU7YEZ6XXLqwJei88,27974
@@ -73,13 +73,13 @@ ml_tools/_core/_ML_optimization.py,sha256=b1qfHiGyvVoj-ENqDbHTf1jNx55niUWE9KEZJv
  ml_tools/_core/_ML_optimization_pareto.py,sha256=7jjV7i-A_J8vizDKg2ZIWNMVRu5oJokRmDbIkhofdlk,34831
  ml_tools/_core/_ML_scaler.py,sha256=Nhu6qli_QezHQi5NKhRb8Z51bBJgzk2nEp_yW4B9H4U,8134
  ml_tools/_core/_ML_sequence_datasetmaster.py,sha256=0YVOPf-y4ZNdgUxropXUWrmInNyGYaUYprYvXf31n9U,17811
- ml_tools/_core/_ML_sequence_evaluation.py,sha256=cBbd_1qPPiggjngzkD2NBwIRWkArKx5VSYcGpyIemj8,8077
+ ml_tools/_core/_ML_sequence_evaluation.py,sha256=AiPHtZ9DRpE6zL9n3Tp5eGGD9vrYRkLbZ0Nc274mL7I,8069
  ml_tools/_core/_ML_sequence_inference.py,sha256=zd3hBwOtLmjAV4JtdB2qFY9GxhysajFufATdy8fjGTE,16316
  ml_tools/_core/_ML_sequence_models.py,sha256=5qcEYLU6wDePBITnikBrj_H9mCvyJmElKa3HiWGXhZs,5639
  ml_tools/_core/_ML_trainer.py,sha256=hSsudWrlYWpi53DXIlKI6ovVhz7xLrQ8oKIDJOXf4Eg,117747
  ml_tools/_core/_ML_utilities.py,sha256=yXVKow-bgpahMChpp7iUlSxAEtgityXwC54FPReeNNA,30487
  ml_tools/_core/_ML_vision_datasetmaster.py,sha256=8EsE7luzphVlwBXdOsOwsFfz1D4UIUSEQtqHlM0Vf-o,67084
- ml_tools/_core/_ML_vision_evaluation.py,sha256=afn9T3B6KH2-fv2mGj3m-apSZc0R5H3as9iMjh3lchg,11625
+ ml_tools/_core/_ML_vision_evaluation.py,sha256=BSLf9xrGpaR02Dhkf-fAbgxSpwRjf7DruNIcQadl7qg,11631
  ml_tools/_core/_ML_vision_inference.py,sha256=6K9gMFjAAZKfLAIQlOkm_I9hvCPmO--9-1vnskQRk0I,20190
  ml_tools/_core/_ML_vision_models.py,sha256=oUik-RLxFvZFZCtFztjkSfFYgJuRx4QzfwHVY1ny4Sc,26217
  ml_tools/_core/_ML_vision_transformers.py,sha256=imjL9h5kwpfuRn9rBelNpgtrdU-EecBEcHMFZMXTeZA,15303
@@ -92,7 +92,7 @@ ml_tools/_core/_ensemble_evaluation.py,sha256=17lWl4bWLT1BAMv_fhGf2D3wy-F4jx0Hgn
  ml_tools/_core/_ensemble_inference.py,sha256=PfZG-r65Vw3IAmBJZg9W0zYGEe-QbhfUh_rd2ho-rr8,8610
  ml_tools/_core/_ensemble_learning.py,sha256=X8ghbjDOLMENCWdISXLhDlHQtR3C6SW1tkTBAcfRRPY,22016
  ml_tools/_core/_excel_handler.py,sha256=gV4rSIsiowb0xllpEJxzUKaYDDVpmP_lxs9wZA76-cc,14050
- ml_tools/_core/_keys.py,sha256=ZfTHB2j1LiNmKTs1MThfmdz_05IWYNAMet_1CtD3oR0,6376
+ ml_tools/_core/_keys.py,sha256=4RE-ZuCJkUmqefz-dc3qrVbftqVAWkunZFrP2yJjpCU,6740
  ml_tools/_core/_logger.py,sha256=86Ge0sDE_WgwsZBglQRYPyFYX3lcsIo0NzszNPzlxuk,5254
  ml_tools/_core/_math_utilities.py,sha256=IlXAiZgTcLtus03jJOBOyF9ZCQDf8qLGjrCHu9Mrgak,9091
  ml_tools/_core/_models_advanced_base.py,sha256=ceW0V_CcfOnSFqHlxUhVU8-5mtQq4tFyo8TX-xVexrY,4982
@@ -104,7 +104,7 @@ ml_tools/_core/_schema.py,sha256=TM5WVVMoKOvr_Bc2z34sU_gzKlM465PRKTgdZaEOkGY,140
  ml_tools/_core/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
  ml_tools/_core/_serde.py,sha256=tsI4EO2Y7jrBMmbQ1pinDsPOrOg-SaPuB-Dt40q0taE,5609
  ml_tools/_core/_utilities.py,sha256=iA8fLWdhsIx4ut2Dp8M_OyU0Y3PPLgGdIklyl17x6xk,22560
- dragon_ml_toolbox-19.6.0.dist-info/METADATA,sha256=4VcRoS7xXnqsfy1bL11_n1beQjtVGwxNt2DN5wLgDzk,8764
- dragon_ml_toolbox-19.6.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dragon_ml_toolbox-19.6.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
- dragon_ml_toolbox-19.6.0.dist-info/RECORD,,
+ dragon_ml_toolbox-19.7.0.dist-info/METADATA,sha256=64wv9eyG4FCm-QRo-9S0wSwARsYo6I-ULnRli6t6UxU,8764
+ dragon_ml_toolbox-19.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-19.7.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-19.7.0.dist-info/RECORD,,
ml_tools/_core/_ML_configuration.py CHANGED
@@ -65,7 +65,11 @@ class _BaseClassificationFormat:
65
65
  cmap: str="BuGn",
66
66
  ROC_PR_line: str='darkorange',
67
67
  calibration_bins: int=15,
68
- font_size: int=16) -> None:
68
+ xtick_size: int=22,
69
+ ytick_size: int=22,
70
+ legend_size: int=26,
71
+ font_size: int=26,
72
+ cm_font_size: int=26) -> None:
69
73
  """
70
74
  Initializes the formatting configuration for single-label classification metrics.
71
75
 
@@ -84,6 +88,14 @@ class _BaseClassificationFormat:
84
88
  creating the calibration (reliability) plot.
85
89
 
86
90
  font_size (int): The base font size to apply to the plots.
91
+
92
+ xtick_size (int): Font size for x-axis tick labels.
93
+
94
+ ytick_size (int): Font size for y-axis tick labels.
95
+
96
+ legend_size (int): Font size for plot legends.
97
+
98
+ cm_font_size (int): Font size for the confusion matrix.
87
99
 
88
100
  <br>
89
101
 
@@ -97,13 +109,21 @@ class _BaseClassificationFormat:
97
109
  self.ROC_PR_line = ROC_PR_line
98
110
  self.calibration_bins = calibration_bins
99
111
  self.font_size = font_size
112
+ self.xtick_size = xtick_size
113
+ self.ytick_size = ytick_size
114
+ self.legend_size = legend_size
115
+ self.cm_font_size = cm_font_size
100
116
 
101
117
  def __repr__(self) -> str:
102
118
  parts = [
103
119
  f"cmap='{self.cmap}'",
104
120
  f"ROC_PR_line='{self.ROC_PR_line}'",
105
121
  f"calibration_bins={self.calibration_bins}",
106
- f"font_size={self.font_size}"
122
+ f"font_size={self.font_size}",
123
+ f"xtick_size={self.xtick_size}",
124
+ f"ytick_size={self.ytick_size}",
125
+ f"legend_size={self.legend_size}",
126
+ f"cm_font_size={self.cm_font_size}"
107
127
  ]
108
128
  return f"{self.__class__.__name__}({', '.join(parts)})"
109
129
 
@@ -115,7 +135,10 @@ class _BaseMultiLabelFormat:
115
135
  def __init__(self,
116
136
  cmap: str = "BuGn",
117
137
  ROC_PR_line: str='darkorange',
118
- font_size: int = 16) -> None:
138
+ font_size: int = 25,
139
+ xtick_size: int=20,
140
+ ytick_size: int=20,
141
+ legend_size: int=23) -> None:
119
142
  """
120
143
  Initializes the formatting configuration for multi-label classification metrics.
121
144
 
@@ -132,6 +155,12 @@ class _BaseMultiLabelFormat:
132
155
 
133
156
  font_size (int): The base font size to apply to the plots.
134
157
 
158
+ xtick_size (int): Font size for x-axis tick labels.
159
+
160
+ ytick_size (int): Font size for y-axis tick labels.
161
+
162
+ legend_size (int): Font size for plot legends.
163
+
135
164
  <br>
136
165
 
137
166
  ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
@@ -143,12 +172,18 @@ class _BaseMultiLabelFormat:
143
172
  self.cmap = cmap
144
173
  self.ROC_PR_line = ROC_PR_line
145
174
  self.font_size = font_size
175
+ self.xtick_size = xtick_size
176
+ self.ytick_size = ytick_size
177
+ self.legend_size = legend_size
146
178
 
147
179
  def __repr__(self) -> str:
148
180
  parts = [
149
181
  f"cmap='{self.cmap}'",
150
182
  f"ROC_PR_line='{self.ROC_PR_line}'",
151
- f"font_size={self.font_size}"
183
+ f"font_size={self.font_size}",
184
+ f"xtick_size={self.xtick_size}",
185
+ f"ytick_size={self.ytick_size}",
186
+ f"legend_size={self.legend_size}"
152
187
  ]
153
188
  return f"{self.__class__.__name__}({', '.join(parts)})"
154
189
 
@@ -158,12 +193,14 @@ class _BaseRegressionFormat:
158
193
  [PRIVATE] Base configuration for regression metrics.
159
194
  """
160
195
  def __init__(self,
161
- font_size: int=16,
196
+ font_size: int=26,
162
197
  scatter_color: str='tab:blue',
163
198
  scatter_alpha: float=0.6,
164
199
  ideal_line_color: str='k',
165
200
  residual_line_color: str='red',
166
- hist_bins: Union[int, str] = 'auto') -> None:
201
+ hist_bins: Union[int, str] = 'auto',
202
+ xtick_size: int=22,
203
+ ytick_size: int=22) -> None:
167
204
  """
168
205
  Initializes the formatting configuration for regression metrics.
169
206
 
@@ -181,6 +218,8 @@ class _BaseRegressionFormat:
181
218
  hist_bins (int | str): The number of bins for the residuals histogram.
182
219
  Defaults to 'auto' to use seaborn's automatic bin selection.
183
220
  - Options: 'auto', 'sqrt', 10, 20
221
+ xtick_size (int): Font size for x-axis tick labels.
222
+ ytick_size (int): Font size for y-axis tick labels.
184
223
 
185
224
  <br>
186
225
 
@@ -192,6 +231,8 @@ class _BaseRegressionFormat:
192
231
  self.ideal_line_color = ideal_line_color
193
232
  self.residual_line_color = residual_line_color
194
233
  self.hist_bins = hist_bins
234
+ self.xtick_size = xtick_size
235
+ self.ytick_size = ytick_size
195
236
 
196
237
  def __repr__(self) -> str:
197
238
  parts = [
@@ -200,7 +241,9 @@ class _BaseRegressionFormat:
200
241
  f"scatter_alpha={self.scatter_alpha}",
201
242
  f"ideal_line_color='{self.ideal_line_color}'",
202
243
  f"residual_line_color='{self.residual_line_color}'",
203
- f"hist_bins='{self.hist_bins}'"
244
+ f"hist_bins='{self.hist_bins}'",
245
+ f"xtick_size={self.xtick_size}",
246
+ f"ytick_size={self.ytick_size}"
204
247
  ]
205
248
  return f"{self.__class__.__name__}({', '.join(parts)})"
206
249
 
@@ -248,7 +291,7 @@ class _BaseSequenceValueFormat:
248
291
  [PRIVATE] Base configuration for sequence to value metrics.
249
292
  """
250
293
  def __init__(self,
251
- font_size: int=16,
294
+ font_size: int=25,
252
295
  scatter_color: str='tab:blue',
253
296
  scatter_alpha: float=0.6,
254
297
  ideal_line_color: str='k',
@@ -300,8 +343,7 @@ class _BaseSequenceSequenceFormat:
300
343
  [PRIVATE] Base configuration for sequence-to-sequence metrics.
301
344
  """
302
345
  def __init__(self,
303
- font_size: int = 16,
304
- plot_figsize: tuple[int, int] = (10, 6),
346
+ font_size: int = 25,
305
347
  grid_style: str = '--',
306
348
  rmse_color: str = 'tab:blue',
307
349
  rmse_marker: str = 'o-',
@@ -312,7 +354,6 @@ class _BaseSequenceSequenceFormat:
312
354
 
313
355
  Args:
314
356
  font_size (int): The base font size to apply to the plots.
315
- plot_figsize (Tuple[int, int]): Figure size for the plot.
316
357
  grid_style (str): Matplotlib linestyle for the plot grid.
317
358
  - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
318
359
  rmse_color (str): Matplotlib color for the RMSE line.
@@ -337,7 +378,6 @@ class _BaseSequenceSequenceFormat:
337
378
  ### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
338
379
  """
339
380
  self.font_size = font_size
340
- self.plot_figsize = plot_figsize
341
381
  self.grid_style = grid_style
342
382
  self.rmse_color = rmse_color
343
383
  self.rmse_marker = rmse_marker
@@ -347,7 +387,6 @@ class _BaseSequenceSequenceFormat:
347
387
  def __repr__(self) -> str:
348
388
  parts = [
349
389
  f"font_size={self.font_size}",
350
- f"plot_figsize={self.plot_figsize}",
351
390
  f"grid_style='{self.grid_style}'",
352
391
  f"rmse_color='{self.rmse_color}'",
353
392
  f"mae_color='{self.mae_color}'"
@@ -639,18 +678,22 @@ class RegressionMetricsFormat(_BaseRegressionFormat):
639
678
  Configuration for single-target regression.
640
679
  """
641
680
  def __init__(self,
642
- font_size: int=16,
681
+ font_size: int=26,
643
682
  scatter_color: str='tab:blue',
644
683
  scatter_alpha: float=0.6,
645
684
  ideal_line_color: str='k',
646
685
  residual_line_color: str='red',
647
- hist_bins: Union[int, str] = 'auto') -> None:
686
+ hist_bins: Union[int, str] = 'auto',
687
+ xtick_size: int=22,
688
+ ytick_size: int=22) -> None:
648
689
  super().__init__(font_size=font_size,
649
690
  scatter_color=scatter_color,
650
691
  scatter_alpha=scatter_alpha,
651
692
  ideal_line_color=ideal_line_color,
652
693
  residual_line_color=residual_line_color,
653
- hist_bins=hist_bins)
694
+ hist_bins=hist_bins,
695
+ xtick_size=xtick_size,
696
+ ytick_size=ytick_size)
654
697
 
655
698
 
656
699
  # Multitarget regression
@@ -659,18 +702,22 @@ class MultiTargetRegressionMetricsFormat(_BaseRegressionFormat):
659
702
  Configuration for multi-target regression.
660
703
  """
661
704
  def __init__(self,
662
- font_size: int=16,
705
+ font_size: int=26,
663
706
  scatter_color: str='tab:blue',
664
707
  scatter_alpha: float=0.6,
665
708
  ideal_line_color: str='k',
666
709
  residual_line_color: str='red',
667
- hist_bins: Union[int, str] = 'auto') -> None:
710
+ hist_bins: Union[int, str] = 'auto',
711
+ xtick_size: int=22,
712
+ ytick_size: int=22) -> None:
668
713
  super().__init__(font_size=font_size,
669
714
  scatter_color=scatter_color,
670
715
  scatter_alpha=scatter_alpha,
671
716
  ideal_line_color=ideal_line_color,
672
717
  residual_line_color=residual_line_color,
673
- hist_bins=hist_bins)
718
+ hist_bins=hist_bins,
719
+ xtick_size=xtick_size,
720
+ ytick_size=ytick_size)
674
721
 
675
722
 
676
723
  # Classification
@@ -682,11 +729,20 @@ class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
682
729
  cmap: str="BuGn",
683
730
  ROC_PR_line: str='darkorange',
684
731
  calibration_bins: int=15,
685
- font_size: int=16) -> None:
732
+ font_size: int=26,
733
+ xtick_size: int=22,
734
+ ytick_size: int=22,
735
+ legend_size: int=26,
736
+ cm_font_size: int=26
737
+ ) -> None:
686
738
  super().__init__(cmap=cmap,
687
739
  ROC_PR_line=ROC_PR_line,
688
740
  calibration_bins=calibration_bins,
689
- font_size=font_size)
741
+ font_size=font_size,
742
+ xtick_size=xtick_size,
743
+ ytick_size=ytick_size,
744
+ legend_size=legend_size,
745
+ cm_font_size=cm_font_size)
690
746
 
691
747
 
692
748
  class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
@@ -697,12 +753,20 @@ class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
697
753
  cmap: str="BuGn",
698
754
  ROC_PR_line: str='darkorange',
699
755
  calibration_bins: int=15,
700
- font_size: int=16) -> None:
756
+ font_size: int=26,
757
+ xtick_size: int=22,
758
+ ytick_size: int=22,
759
+ legend_size: int=26,
760
+ cm_font_size: int=26
761
+ ) -> None:
701
762
  super().__init__(cmap=cmap,
702
763
  ROC_PR_line=ROC_PR_line,
703
764
  calibration_bins=calibration_bins,
704
- font_size=font_size)
705
-
765
+ font_size=font_size,
766
+ xtick_size=xtick_size,
767
+ ytick_size=ytick_size,
768
+ legend_size=legend_size,
769
+ cm_font_size=cm_font_size)
706
770
 
707
771
  class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
708
772
  """
@@ -712,12 +776,20 @@ class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
712
776
  cmap: str="BuGn",
713
777
  ROC_PR_line: str='darkorange',
714
778
  calibration_bins: int=15,
715
- font_size: int=16) -> None:
779
+ font_size: int=26,
780
+ xtick_size: int=22,
781
+ ytick_size: int=22,
782
+ legend_size: int=26,
783
+ cm_font_size: int=26
784
+ ) -> None:
716
785
  super().__init__(cmap=cmap,
717
786
  ROC_PR_line=ROC_PR_line,
718
787
  calibration_bins=calibration_bins,
719
- font_size=font_size)
720
-
788
+ font_size=font_size,
789
+ xtick_size=xtick_size,
790
+ ytick_size=ytick_size,
791
+ legend_size=legend_size,
792
+ cm_font_size=cm_font_size)
721
793
 
722
794
  class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
723
795
  """
@@ -727,12 +799,20 @@ class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
727
799
  cmap: str="BuGn",
728
800
  ROC_PR_line: str='darkorange',
729
801
  calibration_bins: int=15,
730
- font_size: int=16) -> None:
802
+ font_size: int=26,
803
+ xtick_size: int=22,
804
+ ytick_size: int=22,
805
+ legend_size: int=26,
806
+ cm_font_size: int=26
807
+ ) -> None:
731
808
  super().__init__(cmap=cmap,
732
809
  ROC_PR_line=ROC_PR_line,
733
810
  calibration_bins=calibration_bins,
734
- font_size=font_size)
735
-
811
+ font_size=font_size,
812
+ xtick_size=xtick_size,
813
+ ytick_size=ytick_size,
814
+ legend_size=legend_size,
815
+ cm_font_size=cm_font_size)
736
816
 
737
817
  # Multi-Label classification
738
818
  class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
@@ -742,11 +822,17 @@ class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
742
822
  def __init__(self,
743
823
  cmap: str = "BuGn",
744
824
  ROC_PR_line: str='darkorange',
745
- font_size: int = 16) -> None:
825
+ font_size: int = 25,
826
+ xtick_size: int=20,
827
+ ytick_size: int=20,
828
+ legend_size: int=23
829
+ ) -> None:
746
830
  super().__init__(cmap=cmap,
747
831
  ROC_PR_line=ROC_PR_line,
748
- font_size=font_size)
749
-
832
+ font_size=font_size,
833
+ xtick_size=xtick_size,
834
+ ytick_size=ytick_size,
835
+ legend_size=legend_size)
750
836
 
751
837
  # Segmentation
752
838
  class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
@@ -781,7 +867,7 @@ class SequenceValueMetricsFormat(_BaseSequenceValueFormat):
781
867
  Configuration for sequence-to-value prediction.
782
868
  """
783
869
  def __init__(self,
784
- font_size: int=16,
870
+ font_size: int=25,
785
871
  scatter_color: str='tab:blue',
786
872
  scatter_alpha: float=0.6,
787
873
  ideal_line_color: str='k',
@@ -800,15 +886,13 @@ class SequenceSequenceMetricsFormat(_BaseSequenceSequenceFormat):
800
886
  Configuration for sequence-to-sequence prediction.
801
887
  """
802
888
  def __init__(self,
803
- font_size: int = 16,
804
- plot_figsize: tuple[int, int] = (10, 6),
889
+ font_size: int = 25,
805
890
  grid_style: str = '--',
806
891
  rmse_color: str = 'tab:blue',
807
892
  rmse_marker: str = 'o-',
808
893
  mae_color: str = 'tab:orange',
809
894
  mae_marker: str = 's--'):
810
895
  super().__init__(font_size=font_size,
811
- plot_figsize=plot_figsize,
812
896
  grid_style=grid_style,
813
897
  rmse_color=rmse_color,
814
898
  rmse_marker=rmse_marker,
@@ -48,6 +48,7 @@ __all__ = [
48
48
 
49
49
  DPI_value = _EvaluationConfig.DPI
50
50
  REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
51
+ CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE
51
52
 
52
53
 
53
54
  def plot_losses(history: dict, save_dir: Union[str, Path]):
@@ -67,7 +68,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
67
68
  _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
68
69
  return
69
70
 
70
- fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
71
+ fig, ax = plt.subplots(figsize=_EvaluationConfig.LOSS_PLOT_SIZE, dpi=DPI_value)
71
72
 
72
73
  # --- Plot Losses (Left Y-axis) ---
73
74
  line_handles = [] # To store line objects for the legend
@@ -84,10 +85,11 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
84
85
  line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
85
86
  line_handles.append(line2)
86
87
 
87
- ax.set_title('Training and Validation Loss')
88
- ax.set_xlabel('Epochs')
89
- ax.set_ylabel('Loss', color='tab:blue')
90
- ax.tick_params(axis='y', labelcolor='tab:blue')
88
+ ax.set_title('Training and Validation Loss', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE + 2, pad=_EvaluationConfig.LABEL_PADDING)
89
+ ax.set_xlabel('Epochs', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
90
+ ax.set_ylabel('Loss', color='tab:blue', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
91
+ ax.tick_params(axis='y', labelcolor='tab:blue', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
92
+ ax.tick_params(axis='x', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
91
93
  ax.grid(True, linestyle='--')
92
94
 
93
95
  # --- Plot Learning Rate (Right Y-axis) ---
@@ -97,13 +99,17 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
97
99
  line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
98
100
  line_handles.append(line3)
99
101
 
100
- ax2.set_ylabel('Learning Rate', color='g')
101
- ax2.tick_params(axis='y', labelcolor='g')
102
+ ax2.set_ylabel('Learning Rate', color='g', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
103
+ ax2.tick_params(axis='y', labelcolor='g', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
102
104
  # Use scientific notation if the LR is very small
103
105
  ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
106
+ # increase the size of the scientific notation
107
+ ax2.yaxis.get_offset_text().set_fontsize(_EvaluationConfig.LOSS_PLOT_TICK_SIZE - 2)
108
+ # remove grid from second y-axis
109
+ ax2.grid(False)
104
110
 
105
111
  # Combine legends from both axes
106
- ax.legend(handles=line_handles, loc='best')
112
+ ax.legend(handles=line_handles, loc='best', fontsize=_EvaluationConfig.LOSS_PLOT_LEGEND_SIZE)
107
113
 
108
114
  # ax.grid(True)
109
115
  plt.tight_layout()
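The `plot_losses` hunks above replace the hard-coded figure size and default fonts with the `_EvaluationConfig.LOSS_PLOT_*` constants and explicitly size tick labels, axis labels, the legend, and the scientific-notation offset text on the learning-rate axis. A standalone matplotlib sketch of that styling pattern, with the constants inlined and toy data; this is not the library's code:

```python
import matplotlib.pyplot as plt

LABEL_SIZE, TICK_SIZE, LEGEND_SIZE, PAD = 24, 22, 24, 10  # mirrors LOSS_PLOT_* / LABEL_PADDING

epochs = [1, 2, 3, 4, 5]
train_loss = [1.0, 0.7, 0.5, 0.4, 0.35]
val_loss = [1.1, 0.8, 0.6, 0.5, 0.45]
lr = [1e-3, 1e-3, 5e-4, 5e-4, 2.5e-4]

fig, ax = plt.subplots(figsize=(18, 9), dpi=100)  # LOSS_PLOT_SIZE; lower dpi for a quick preview
l1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
l2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
ax.set_title('Training and Validation Loss', fontsize=LABEL_SIZE + 2, pad=PAD)
ax.set_xlabel('Epochs', fontsize=LABEL_SIZE, labelpad=PAD)
ax.set_ylabel('Loss', color='tab:blue', fontsize=LABEL_SIZE, labelpad=PAD)
ax.tick_params(axis='both', labelsize=TICK_SIZE)
ax.grid(True, linestyle='--')

ax2 = ax.twinx()                                   # learning rate on the right y-axis
l3, = ax2.plot(epochs, lr, 'g--', label='Learning Rate')
ax2.set_ylabel('Learning Rate', color='g', fontsize=LABEL_SIZE, labelpad=PAD)
ax2.tick_params(axis='y', labelcolor='g', labelsize=TICK_SIZE)
ax2.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
ax2.yaxis.get_offset_text().set_fontsize(TICK_SIZE - 2)  # enlarge the "1e-3" offset label too
ax2.grid(False)                                    # keep only the left axis grid

ax.legend(handles=[l1, l2, l3], loc='best', fontsize=LEGEND_SIZE)
plt.tight_layout()
plt.savefig("loss_plot_preview.svg")
```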
@@ -142,10 +148,17 @@ def classification_metrics(save_dir: Union[str, Path],
142
148
  else:
143
149
  format_config = config
144
150
 
145
- original_rc_params = plt.rcParams.copy()
146
- plt.rcParams.update({'font.size': format_config.font_size})
151
+ # original_rc_params = plt.rcParams.copy()
152
+ # plt.rcParams.update({'font.size': format_config.font_size})
147
153
 
148
- # print("--- Classification Report ---")
154
+ # --- Set Font Sizes ---
155
+ xtick_size = format_config.xtick_size
156
+ ytick_size = format_config.ytick_size
157
+ legend_size = format_config.legend_size
158
+
159
+ # config font size for heatmap
160
+ cm_font_size = format_config.cm_font_size
161
+ cm_tick_size = cm_font_size - 4
149
162
 
150
163
  # --- Parse class_map ---
151
164
  map_labels = None
@@ -176,61 +189,122 @@ def classification_metrics(save_dir: Union[str, Path],
176
189
  try:
177
190
  # Create DataFrame from report
178
191
  report_df = pd.DataFrame(report_dict)
179
-
180
- # 1. Drop the 'accuracy' column (single float)
181
- if 'accuracy' in report_df.columns:
182
- report_df = report_df.drop(columns=['accuracy'])
183
-
184
- # 2. Select all metric rows *except* the last one ('support')
185
- # 3. Transpose the DataFrame
186
- plot_df = report_df.iloc[:-1, :].T
187
-
188
- fig_height = max(5.0, len(plot_df.index) * 0.5 + 2.0)
189
- plt.figure(figsize=(7, fig_height), dpi=DPI_value)
190
192
 
191
- sns.set_theme(font_scale=1.2) # Scale seaborn font
193
+ # 1. Robust Cleanup: Drop by name, not position
194
+ # Remove 'accuracy' column if it exists (handles the scalar value issue)
195
+ report_df = report_df.drop(columns=['accuracy'], errors='ignore')
196
+
197
+ # Remove 'support' row explicitly (safer than iloc[:-1])
198
+ if 'support' in report_df.index:
199
+ report_df = report_df.drop(index='support')
200
+
201
+ # 2. Transpose: Rows = Classes, Cols = Metrics
202
+ plot_df = report_df.T
203
+
204
+ # 3. Dynamic Height Calculation
205
+ # (Base height of 4 + 0.5 inches per class row)
206
+ fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
207
+ fig_width = 8.0 # Set a fixed width
208
+
209
+ # --- Use calculated dimensions, not the config constant ---
210
+ fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
211
+
212
+ # sns.set_theme(font_scale=1.4)
192
213
  sns.heatmap(plot_df,
193
214
  annot=True,
194
215
  cmap=format_config.cmap,
195
216
  fmt='.2f',
196
217
  vmin=0.0,
197
- vmax=1.0)
198
- sns.set_theme(font_scale=1.0) # Reset seaborn scale
199
- plt.title("Classification Report Heatmap")
218
+ vmax=1.0,
219
+ cbar_kws={'shrink': 0.9}) # Shrink colorbar slightly to fit better
220
+
221
+ # sns.set_theme(font_scale=1.0)
222
+
223
+ ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
224
+
225
+ # manually increase the font size of the elements
226
+ for text in ax_heat.texts:
227
+ text.set_fontsize(cm_tick_size)
228
+
229
+ # manually increase the size of the colorbar ticks
230
+ cbar = ax_heat.collections[0].colorbar
231
+ cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore
232
+
233
+ # Update Ticks
234
+ ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
235
+ ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0) # Ensure Y labels are horizontal
236
+
200
237
  plt.tight_layout()
238
+
201
239
  heatmap_path = save_dir_path / "classification_report_heatmap.svg"
202
240
  plt.savefig(heatmap_path)
203
241
  _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
204
- plt.close()
242
+ plt.close(fig_heat)
243
+
205
244
  except Exception as e:
206
245
  _LOGGER.error(f"Could not generate classification report heatmap: {e}")
207
-
246
+
208
247
  # --- labels for Confusion Matrix ---
209
248
  plot_labels = map_labels
210
249
  plot_display_labels = map_display_labels
211
250
 
212
- # Save Confusion Matrix
213
- fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
251
+ # 1. DYNAMIC SIZE CALCULATION
252
+ # Calculate figure size based on number of classes.
253
+ n_classes = len(plot_labels) if plot_labels is not None else len(np.unique(y_true))
254
+ # Ensure a minimum size so very small matrices aren't tiny
255
+ fig_w = max(9, n_classes * 0.8 + 3)
256
+ fig_h = max(8, n_classes * 0.8 + 2)
257
+
258
+ # Use the calculated size instead of CLASSIFICATION_PLOT_SIZE
259
+ fig_cm, ax_cm = plt.subplots(figsize=(fig_w, fig_h), dpi=DPI_value)
214
260
  disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
215
261
  y_pred,
216
262
  cmap=format_config.cmap,
217
263
  ax=ax_cm,
218
264
  normalize='true',
219
265
  labels=plot_labels,
220
- display_labels=plot_display_labels)
266
+ display_labels=plot_display_labels,
267
+ colorbar=False)
221
268
 
222
269
  disp_.im_.set_clim(vmin=0.0, vmax=1.0)
223
270
 
224
271
  # Turn off gridlines
225
272
  ax_cm.grid(False)
226
273
 
227
- # Manually update font size of cell texts
274
+ # 2. CHECK FOR FONT CLASH
275
+ # If matrix is huge, force text smaller. If small, allow user config.
276
+ final_font_size = cm_font_size + 2
277
+ if n_classes > 2:
278
+ final_font_size = cm_font_size - n_classes # Decrease font size for larger matrices
279
+
228
280
  for text in ax_cm.texts:
229
- text.set_fontsize(format_config.font_size)
281
+ text.set_fontsize(final_font_size)
282
+
283
+ # Update Ticks for Confusion Matrix
284
+ ax_cm.tick_params(axis='x', labelsize=cm_tick_size)
285
+ ax_cm.tick_params(axis='y', labelsize=cm_tick_size)
286
+
287
+ #if more than 3 classes, rotate x ticks
288
+ if n_classes > 3:
289
+ plt.setp(ax_cm.get_xticklabels(), rotation=45, ha='right', rotation_mode="anchor")
230
290
 
291
+ # Set titles and labels with padding
292
+ ax_cm.set_title("Confusion Matrix", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size + 2)
293
+ ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
294
+ ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
295
+
296
+ # --- ADJUST COLORBAR FONT & SIZE---
297
+ # Manually add the colorbar with the 'shrink' parameter
298
+ cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
299
+
300
+ # Update the tick size on the new cbar object
301
+ cbar.ax.tick_params(labelsize=cm_tick_size)
302
+
303
+ # (Optional) add a label to the bar itself (e.g. "Probability")
304
+ # cbar.set_label('Probability', fontsize=12)
305
+
231
306
  fig_cm.tight_layout()
232
307
 
233
- ax_cm.set_title("Confusion Matrix")
234
308
  cm_path = save_dir_path / "confusion_matrix.svg"
235
309
  plt.savefig(cm_path)
236
310
  _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
@@ -335,34 +409,50 @@ def classification_metrics(save_dir: Union[str, Path],
335
409
  # Calculate AUC.
336
410
  auc = roc_auc_score(y_true_binary, y_score)
337
411
 
338
- fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
412
+ fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
339
413
  ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
340
414
  ax_roc.plot([0, 1], [0, 1], 'k--')
341
- ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
342
- ax_roc.set_xlabel('False Positive Rate')
343
- ax_roc.set_ylabel('True Positive Rate')
344
- ax_roc.legend(loc='lower right')
415
+ ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
416
+ ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
417
+ ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
418
+
419
+ # Apply Ticks and Legend sizing
420
+ ax_roc.tick_params(axis='x', labelsize=xtick_size)
421
+ ax_roc.tick_params(axis='y', labelsize=ytick_size)
422
+ ax_roc.legend(loc='lower right', fontsize=legend_size)
423
+
345
424
  ax_roc.grid(True)
346
425
  roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
426
+
427
+ plt.tight_layout()
428
+
347
429
  plt.savefig(roc_path)
348
430
  plt.close(fig_roc)
349
431
 
350
432
  # --- Save Precision-Recall Curve ---
351
433
  precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
352
434
  ap_score = average_precision_score(y_true_binary, y_score)
353
- fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
435
+ fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
354
436
  ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
355
- ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
356
- ax_pr.set_xlabel('Recall')
357
- ax_pr.set_ylabel('Precision')
358
- ax_pr.legend(loc='lower left')
437
+ ax_pr.set_title(f'Precision-Recall Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
438
+ ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
439
+ ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
440
+
441
+ # Apply Ticks and Legend sizing
442
+ ax_pr.tick_params(axis='x', labelsize=xtick_size)
443
+ ax_pr.tick_params(axis='y', labelsize=ytick_size)
444
+ ax_pr.legend(loc='lower left', fontsize=legend_size)
445
+
359
446
  ax_pr.grid(True)
360
447
  pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
448
+
449
+ plt.tight_layout()
450
+
361
451
  plt.savefig(pr_path)
362
452
  plt.close(fig_pr)
363
453
 
364
454
  # --- Save Calibration Plot ---
365
- fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
455
+ fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
366
456
 
367
457
  # --- Step 1: Get binned data *without* plotting ---
368
458
  with plt.ioff(): # Suppress showing the temporary plot
@@ -386,7 +476,7 @@ def classification_metrics(save_dir: Union[str, Path],
386
476
  y=line_y,
387
477
  ax=ax_cal,
388
478
  scatter=False,
389
- label=f"Calibration Curve ({format_config.calibration_bins} bins)",
479
+ label=f"Model calibration",
390
480
  line_kws={
391
481
  'color': format_config.ROC_PR_line,
392
482
  'linestyle': '--',
@@ -394,15 +484,19 @@ def classification_metrics(save_dir: Union[str, Path],
394
484
  }
395
485
  )
396
486
 
397
- ax_cal.set_title(f'Reliability Curve{plot_title}')
398
- ax_cal.set_xlabel('Mean Predicted Probability')
399
- ax_cal.set_ylabel('Fraction of Positives')
487
+ ax_cal.set_title(f'Reliability Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
488
+ ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
489
+ ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
400
490
 
401
491
  # --- Step 3: Set final limits *after* plotting ---
402
492
  ax_cal.set_ylim(0.0, 1.0)
403
493
  ax_cal.set_xlim(0.0, 1.0)
404
494
 
405
- ax_cal.legend(loc='lower right')
495
+ # Apply Ticks and Legend sizing
496
+ ax_cal.tick_params(axis='x', labelsize=xtick_size)
497
+ ax_cal.tick_params(axis='y', labelsize=ytick_size)
498
+ ax_cal.legend(loc='lower right', fontsize=legend_size)
499
+
406
500
  ax_cal.grid(True)
407
501
  plt.tight_layout()
408
502
 
@@ -413,7 +507,7 @@ def classification_metrics(save_dir: Union[str, Path],
413
507
  _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
414
508
 
415
509
  # restore RC params
416
- plt.rcParams.update(original_rc_params)
510
+ # plt.rcParams.update(original_rc_params)
417
511
 
418
512
 
419
513
  def regression_metrics(
@@ -440,8 +534,13 @@ def regression_metrics(
440
534
  format_config = config
441
535
 
442
536
  # --- Set Matplotlib font size ---
443
- original_rc_params = plt.rcParams.copy()
444
- plt.rcParams.update({'font.size': format_config.font_size})
537
+ # original_rc_params = plt.rcParams.copy()
538
+ # plt.rcParams.update({'font.size': format_config.font_size})
539
+
540
+ # --- Resolve Font Sizes ---
541
+ xtick_size = format_config.xtick_size
542
+ ytick_size = format_config.ytick_size
543
+ base_font_size = format_config.font_size
445
544
 
446
545
  # --- Calculate Metrics ---
447
546
  rmse = np.sqrt(mean_squared_error(y_true, y_pred))
@@ -472,9 +571,14 @@ def regression_metrics(
472
571
  alpha=format_config.scatter_alpha,
473
572
  color=format_config.scatter_color)
474
573
  ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
475
- ax_res.set_xlabel("Predicted Values")
476
- ax_res.set_ylabel("Residuals")
477
- ax_res.set_title("Residual Plot")
574
+ ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
575
+ ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
576
+ ax_res.set_title("Residual Plot", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
577
+
578
+ # Apply Ticks
579
+ ax_res.tick_params(axis='x', labelsize=xtick_size)
580
+ ax_res.tick_params(axis='y', labelsize=ytick_size)
581
+
478
582
  ax_res.grid(True)
479
583
  plt.tight_layout()
480
584
  res_path = save_dir_path / "residual_plot.svg"
@@ -491,9 +595,14 @@ def regression_metrics(
491
595
  linestyle='--',
492
596
  lw=2,
493
597
  color=format_config.ideal_line_color)
494
- ax_tvp.set_xlabel('True Values')
495
- ax_tvp.set_ylabel('Predictions')
496
- ax_tvp.set_title('True vs. Predicted Values')
598
+ ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
599
+ ax_tvp.set_ylabel('Predictions', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
600
+ ax_tvp.set_title('True vs. Predicted Values', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
601
+
602
+ # Apply Ticks
603
+ ax_tvp.tick_params(axis='x', labelsize=xtick_size)
604
+ ax_tvp.tick_params(axis='y', labelsize=ytick_size)
605
+
497
606
  ax_tvp.grid(True)
498
607
  plt.tight_layout()
499
608
  tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
@@ -506,9 +615,14 @@ def regression_metrics(
506
615
  sns.histplot(residuals, kde=True, ax=ax_hist,
507
616
  bins=format_config.hist_bins,
508
617
  color=format_config.scatter_color)
509
- ax_hist.set_xlabel("Residual Value")
510
- ax_hist.set_ylabel("Frequency")
511
- ax_hist.set_title("Distribution of Residuals")
618
+ ax_hist.set_xlabel("Residual Value", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
619
+ ax_hist.set_ylabel("Frequency", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
620
+ ax_hist.set_title("Distribution of Residuals", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
621
+
622
+ # Apply Ticks
623
+ ax_hist.tick_params(axis='x', labelsize=xtick_size)
624
+ ax_hist.tick_params(axis='y', labelsize=ytick_size)
625
+
512
626
  ax_hist.grid(True)
513
627
  plt.tight_layout()
514
628
  hist_path = save_dir_path / "residuals_histogram.svg"
@@ -517,7 +631,7 @@ def regression_metrics(
517
631
  plt.close(fig_hist)
518
632
 
519
633
  # --- Restore RC params ---
520
- plt.rcParams.update(original_rc_params)
634
+ # plt.rcParams.update(original_rc_params)
521
635
 
522
636
 
523
637
  def shap_summary_plot(model,
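Earlier in this file, the classification-report heatmap switched from positional slicing (`iloc[:-1]`) to dropping the `accuracy` column and `support` row by name before transposing, which tolerates reports that lack either entry. A small sketch of that cleanup on a toy report; the class names are made up and the snippet only reproduces the DataFrame handling, not the plotting:

```python
import pandas as pd
from sklearn.metrics import classification_report

y_true = [0, 0, 1, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 0, 2, 2]
report_dict = classification_report(y_true, y_pred,
                                     target_names=["cat", "dog", "bird"],  # hypothetical labels
                                     output_dict=True, zero_division=0)

report_df = pd.DataFrame(report_dict)
report_df = report_df.drop(columns=['accuracy'], errors='ignore')   # scalar column, if present
if 'support' in report_df.index:
    report_df = report_df.drop(index='support')                      # drop by name, not position
plot_df = report_df.T                                                # rows = classes, cols = metrics

print(plot_df.round(2))  # ready for sns.heatmap(plot_df, annot=True, vmin=0.0, vmax=1.0)
```

The hunks that follow belong to `ml_tools/_core/_ML_evaluation_multi.py`, per the RECORD.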
@@ -44,6 +44,7 @@ __all__ = [
44
44
 
45
45
  DPI_value = _EvaluationConfig.DPI
46
46
  REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
47
+ CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE
47
48
 
48
49
 
49
50
  def multi_target_regression_metrics(
@@ -88,8 +89,13 @@ def multi_target_regression_metrics(
88
89
  format_config = config
89
90
 
90
91
  # --- Set Matplotlib font size ---
91
- original_rc_params = plt.rcParams.copy()
92
- plt.rcParams.update({'font.size': format_config.font_size})
92
+ # original_rc_params = plt.rcParams.copy()
93
+ # plt.rcParams.update({'font.size': format_config.font_size})
94
+
95
+ # ticks font sizes
96
+ xtick_size = format_config.xtick_size
97
+ ytick_size = format_config.ytick_size
98
+ base_font_size = format_config.font_size
93
99
 
94
100
  _LOGGER.debug("--- Multi-Target Regression Evaluation ---")
95
101
 
@@ -105,11 +111,11 @@ def multi_target_regression_metrics(
105
111
  r2 = r2_score(true_i, pred_i)
106
112
  medae = median_absolute_error(true_i, pred_i)
107
113
  metrics_summary.append({
108
- 'target': name,
109
- 'rmse': rmse,
110
- 'mae': mae,
111
- 'r2_score': r2,
112
- 'median_abs_error': medae
114
+ 'Target': name,
115
+ 'RMSE': rmse,
116
+ 'MAE': mae,
117
+ 'MedAE': medae,
118
+ 'R2-score': r2,
113
119
  })
114
120
 
115
121
  # --- Save Residual Plot ---
@@ -121,9 +127,14 @@ def multi_target_regression_metrics(
121
127
  s=50,
122
128
  color=format_config.scatter_color) # Use config color
123
129
  ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--') # Use config color
124
- ax_res.set_xlabel("Predicted Values")
125
- ax_res.set_ylabel("Residuals")
126
- ax_res.set_title(f"Residual Plot for '{name}'")
130
+ ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
131
+ ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
132
+ ax_res.set_title(f"Residual Plot for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
133
+
134
+ # Apply Ticks
135
+ ax_res.tick_params(axis='x', labelsize=xtick_size)
136
+ ax_res.tick_params(axis='y', labelsize=ytick_size)
137
+
127
138
  ax_res.grid(True, linestyle='--', alpha=0.6)
128
139
  plt.tight_layout()
129
140
  res_path = save_dir_path / f"residual_plot_{sanitized_name}.svg"
@@ -141,9 +152,14 @@ def multi_target_regression_metrics(
141
152
  linestyle='--',
142
153
  lw=2,
143
154
  color=format_config.ideal_line_color) # Use config color
144
- ax_tvp.set_xlabel('True Values')
145
- ax_tvp.set_ylabel('Predicted Values')
146
- ax_tvp.set_title(f"True vs. Predicted for '{name}'")
155
+ ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
156
+ ax_tvp.set_ylabel('Predicted Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
157
+ ax_tvp.set_title(f"True vs. Predicted for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
158
+
159
+ # Apply Ticks
160
+ ax_tvp.tick_params(axis='x', labelsize=xtick_size)
161
+ ax_tvp.tick_params(axis='y', labelsize=ytick_size)
162
+
147
163
  ax_tvp.grid(True, linestyle='--', alpha=0.6)
148
164
  plt.tight_layout()
149
165
  tvp_path = save_dir_path / f"true_vs_predicted_plot_{sanitized_name}.svg"
@@ -157,7 +173,7 @@ def multi_target_regression_metrics(
157
173
  _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
158
174
 
159
175
  # --- Restore RC params ---
160
- plt.rcParams.update(original_rc_params)
176
+ # plt.rcParams.update(original_rc_params)
161
177
 
162
178
 
163
179
  def multi_label_classification_metrics(
@@ -205,10 +221,14 @@ def multi_label_classification_metrics(
205
221
  # y_pred is now passed in directly, no threshold needed.
206
222
 
207
223
  # --- Save current RC params and update font size ---
208
- original_rc_params = plt.rcParams.copy()
209
- plt.rcParams.update({'font.size': format_config.font_size})
224
+ # original_rc_params = plt.rcParams.copy()
225
+ # plt.rcParams.update({'font.size': format_config.font_size})
210
226
 
211
- # _LOGGER.info("--- Multi-Label Classification Evaluation ---")
227
+ # ticks and legend font sizes
228
+ xtick_size = format_config.xtick_size
229
+ ytick_size = format_config.ytick_size
230
+ legend_size = format_config.legend_size
231
+ base_font_size = format_config.font_size
212
232
 
213
233
  # --- Calculate and Save Overall Metrics (using y_pred) ---
214
234
  h_loss = hamming_loss(y_true, y_pred)
@@ -224,7 +244,7 @@ def multi_label_classification_metrics(
224
244
  f"--------------------------------------------------\n"
225
245
  )
226
246
  # print(overall_report)
227
- overall_report_path = save_dir_path / "classification_report_overall.txt"
247
+ overall_report_path = save_dir_path / "classification_report.txt"
228
248
  overall_report_path.write_text(overall_report)
229
249
 
230
250
  # --- Per-Label Metrics and Plots ---
@@ -241,14 +261,15 @@ def multi_label_classification_metrics(
241
261
  report_path.write_text(report_text) # type: ignore
242
262
 
243
263
  # --- Save Confusion Matrix (uses y_pred) ---
244
- fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
264
+ fig_cm, ax_cm = plt.subplots(figsize=_EvaluationConfig.CM_SIZE, dpi=_EvaluationConfig.DPI)
245
265
  disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
246
266
  pred_i,
247
267
  cmap=format_config.cmap, # Use config cmap
248
268
  ax=ax_cm,
249
269
  normalize='true',
250
270
  labels=[0, 1],
251
- display_labels=["Negative", "Positive"])
271
+ display_labels=["Negative", "Positive"],
272
+ colorbar=False)
252
273
 
253
274
  disp_.im_.set_clim(vmin=0.0, vmax=1.0)
254
275
 
@@ -257,11 +278,26 @@ def multi_label_classification_metrics(
257
278
 
258
279
  # Manually update font size of cell texts
259
280
  for text in ax_cm.texts:
260
- text.set_fontsize(format_config.font_size) # Use config font_size
281
+ text.set_fontsize(base_font_size + 2) # Use config font_size
282
+
283
+ # Apply ticks
284
+ ax_cm.tick_params(axis='x', labelsize=xtick_size)
285
+ ax_cm.tick_params(axis='y', labelsize=ytick_size)
286
+
287
+ # Set titles and labels with padding
288
+ ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
289
+ ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
290
+ ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
291
+
292
+ # --- ADJUST COLORBAR FONT & SIZE---
293
+ # Manually add the colorbar with the 'shrink' parameter
294
+ cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
261
295
 
262
- fig_cm.tight_layout()
296
+ # Update the tick size on the new cbar object
297
+ cbar.ax.tick_params(labelsize=ytick_size) # type: ignore
298
+
299
+ plt.tight_layout()
263
300
 
264
- ax_cm.set_title(f"Confusion Matrix for '{name}'")
265
301
  cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
266
302
  plt.savefig(cm_path)
267
303
  plt.close(fig_cm)
@@ -302,12 +338,23 @@ def multi_label_classification_metrics(
302
338
  _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
303
339
 
304
340
  auc = roc_auc_score(true_i, prob_i)
305
- fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
341
+ fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
306
342
  ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
307
343
  ax_roc.plot([0, 1], [0, 1], 'k--')
308
- ax_roc.set_title(f'ROC Curve for "{name}"')
309
- ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
310
- ax_roc.legend(loc='lower right'); ax_roc.grid(True, linestyle='--', alpha=0.6)
344
+
345
+ ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
346
+ ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
347
+ ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
348
+
349
+ # Apply ticks and legend font size
350
+ ax_roc.tick_params(axis='x', labelsize=xtick_size)
351
+ ax_roc.tick_params(axis='y', labelsize=ytick_size)
352
+ ax_roc.legend(loc='lower right', fontsize=legend_size)
353
+
354
+ ax_roc.grid(True, linestyle='--', alpha=0.6)
355
+
356
+ plt.tight_layout()
357
+
311
358
  roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
312
359
  plt.savefig(roc_path)
313
360
  plt.close(fig_roc)
@@ -315,17 +362,27 @@ def multi_label_classification_metrics(
315
362
  # --- Save Precision-Recall Curve (uses y_prob) ---
316
363
  precision, recall, _ = precision_recall_curve(true_i, prob_i)
317
364
  ap_score = average_precision_score(true_i, prob_i)
318
- fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
365
+ fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
319
366
  ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
320
- ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
321
- ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
322
- ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
367
+ ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
368
+ ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
369
+ ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
370
+
371
+ # Apply ticks and legend font size
372
+ ax_pr.tick_params(axis='x', labelsize=xtick_size)
373
+ ax_pr.tick_params(axis='y', labelsize=ytick_size)
374
+ ax_pr.legend(loc='lower left', fontsize=legend_size)
375
+
376
+ ax_pr.grid(True, linestyle='--', alpha=0.6)
377
+
378
+ fig_pr.tight_layout()
379
+
323
380
  pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
324
381
  plt.savefig(pr_path)
325
382
  plt.close(fig_pr)
326
383
 
327
384
  # restore RC params
328
- plt.rcParams.update(original_rc_params)
385
+ # plt.rcParams.update(original_rc_params)
329
386
 
330
387
  _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
331
388
 
@@ -186,7 +186,7 @@ def sequence_to_sequence_metrics(
  _LOGGER.info(f"📝 Seq-to-Seq per-step report saved as '{report_path.name}'")
 
  # --- Create and save plot ---
- fig, ax1 = plt.subplots(figsize=format_config.plot_figsize, dpi=DPI_value)
+ fig, ax1 = plt.subplots(figsize=SEQUENCE_PLOT_SIZE, dpi=DPI_value)
 
  # Plot RMSE
  color_rmse = format_config.rmse_color
@@ -165,7 +165,7 @@ def segmentation_metrics(
  cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)
 
  # Plot
- fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=100)
+ fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=DPI_value)
  disp = ConfusionMatrixDisplay(
  confusion_matrix=cm,
  display_labels=display_names
ml_tools/_core/_keys.py CHANGED
@@ -223,10 +223,19 @@ class SchemaKeys:
 
  class _EvaluationConfig:
  """Set config values for evaluation modules."""
- DPI = 250
- REGRESSION_PLOT_SIZE = (9, 6)
- SEQUENCE_PLOT_SIZE = (9, 6)
-
+ DPI = 400
+ LABEL_PADDING = 10
+ # large sizes for SVG layout to accommodate large fonts
+ REGRESSION_PLOT_SIZE = (12, 8)
+ SEQUENCE_PLOT_SIZE = (12, 8)
+ CLASSIFICATION_PLOT_SIZE = (10, 10)
+ # Loss plot
+ LOSS_PLOT_SIZE = (18, 9)
+ LOSS_PLOT_LABEL_SIZE = 24
+ LOSS_PLOT_TICK_SIZE = 22
+ LOSS_PLOT_LEGEND_SIZE = 24
+ # CM settings
+ CM_SIZE = (9, 8) # used for multi label binary classification confusion matrix
 
  class _OneHotOtherPlaceholder:
  """Used internally by GUI_tools."""