dragon-ml-toolbox 19.6.0__py3-none-any.whl → 19.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/RECORD +12 -12
- ml_tools/_core/_ML_configuration.py +121 -37
- ml_tools/_core/_ML_evaluation.py +177 -63
- ml_tools/_core/_ML_evaluation_multi.py +89 -32
- ml_tools/_core/_ML_sequence_evaluation.py +1 -1
- ml_tools/_core/_ML_vision_evaluation.py +1 -1
- ml_tools/_core/_keys.py +13 -4
- {dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/RECORD
CHANGED

@@ -1,5 +1,5 @@
- dragon_ml_toolbox-19.
- dragon_ml_toolbox-19.
+ dragon_ml_toolbox-19.7.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+ dragon_ml_toolbox-19.7.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=XBLtvGjvBf-q93a5iylHj94Lm78UzInC-3Cii01jc6I,3127
  ml_tools/ETL_cleaning.py,sha256=cKXyRFaaFs_beAGDnQM54xnML671kq-yJEGjHafW-20,351
  ml_tools/ETL_engineering.py,sha256=cwh1FhtNdUHllUDvho-x3SIVj4KwG_rFQR6VYzWUg0U,898
  ml_tools/GUI_tools.py,sha256=O89rG8WQv6GY1DiphQjIsPzXFCQID6te7q_Sgt1iTkQ,294
@@ -58,12 +58,12 @@ ml_tools/_core/_MICE_imputation.py,sha256=_juIymUnNDRWjSLepL8Ee_PncoShbxjR7YtqTt
  ml_tools/_core/_ML_callbacks.py,sha256=qtCrVFHTq-nk4NIsAdwIkfkKwFXX6I-6PoCgqZELp70,16734
  ml_tools/_core/_ML_chaining_inference.py,sha256=vXUPZzuQ2yKU71kkvUsE0xPo0hN-Yu6gfnL0JbXoRjI,7783
  ml_tools/_core/_ML_chaining_utilities.py,sha256=nsYowgRbkIYuzRiHlqsM3tnC3c-8O73CY8DHUF14XL0,19248
- ml_tools/_core/_ML_configuration.py,sha256=
+ ml_tools/_core/_ML_configuration.py,sha256=t_6p_slOPhmy04wJcQj6D_bJyZRaXlsIeWIiULaJnXc,48716
  ml_tools/_core/_ML_configuration_pytab.py,sha256=C3e4iScqdRePVDoqnic6xXMOW7DNYqpgTCeaFDyMdL4,3286
  ml_tools/_core/_ML_datasetmaster.py,sha256=yU1BMtzz6XumMWCetVACrRLk7WJQwmYhaQ-VAWu9Ots,32043
- ml_tools/_core/_ML_evaluation.py,sha256=
+ ml_tools/_core/_ML_evaluation.py,sha256=bu8qlYzhWSC1B7wNfCC5TSF-oed-uP8EF7TV45VTiBM,37325
  ml_tools/_core/_ML_evaluation_captum.py,sha256=a69jnghIzE9qppuw2vzTBMdTErnZkDkTA3MPUUYjsS4,19212
- ml_tools/_core/_ML_evaluation_multi.py,sha256=
+ ml_tools/_core/_ML_evaluation_multi.py,sha256=n_AJbKF58DMUrYqJutwPFV5z6sNssDPA1Gl05IfPG5s,23647
  ml_tools/_core/_ML_finalize_handler.py,sha256=0eZ_0N2L5aUUIJUgvhAQ-rbd8XbE9UmNqTKSJq09uTI,6987
  ml_tools/_core/_ML_inference.py,sha256=5swm2lnsrDLalBnCm7gZPlDucX4yNCq5vn7ck3SW_4Q,29791
  ml_tools/_core/_ML_models.py,sha256=8FUx4-TVghlBF9srh1_5UxovrWPU7YEZ6XXLqwJei88,27974
@@ -73,13 +73,13 @@ ml_tools/_core/_ML_optimization.py,sha256=b1qfHiGyvVoj-ENqDbHTf1jNx55niUWE9KEZJv
  ml_tools/_core/_ML_optimization_pareto.py,sha256=7jjV7i-A_J8vizDKg2ZIWNMVRu5oJokRmDbIkhofdlk,34831
  ml_tools/_core/_ML_scaler.py,sha256=Nhu6qli_QezHQi5NKhRb8Z51bBJgzk2nEp_yW4B9H4U,8134
  ml_tools/_core/_ML_sequence_datasetmaster.py,sha256=0YVOPf-y4ZNdgUxropXUWrmInNyGYaUYprYvXf31n9U,17811
- ml_tools/_core/_ML_sequence_evaluation.py,sha256=
+ ml_tools/_core/_ML_sequence_evaluation.py,sha256=AiPHtZ9DRpE6zL9n3Tp5eGGD9vrYRkLbZ0Nc274mL7I,8069
  ml_tools/_core/_ML_sequence_inference.py,sha256=zd3hBwOtLmjAV4JtdB2qFY9GxhysajFufATdy8fjGTE,16316
  ml_tools/_core/_ML_sequence_models.py,sha256=5qcEYLU6wDePBITnikBrj_H9mCvyJmElKa3HiWGXhZs,5639
  ml_tools/_core/_ML_trainer.py,sha256=hSsudWrlYWpi53DXIlKI6ovVhz7xLrQ8oKIDJOXf4Eg,117747
  ml_tools/_core/_ML_utilities.py,sha256=yXVKow-bgpahMChpp7iUlSxAEtgityXwC54FPReeNNA,30487
  ml_tools/_core/_ML_vision_datasetmaster.py,sha256=8EsE7luzphVlwBXdOsOwsFfz1D4UIUSEQtqHlM0Vf-o,67084
- ml_tools/_core/_ML_vision_evaluation.py,sha256=
+ ml_tools/_core/_ML_vision_evaluation.py,sha256=BSLf9xrGpaR02Dhkf-fAbgxSpwRjf7DruNIcQadl7qg,11631
  ml_tools/_core/_ML_vision_inference.py,sha256=6K9gMFjAAZKfLAIQlOkm_I9hvCPmO--9-1vnskQRk0I,20190
  ml_tools/_core/_ML_vision_models.py,sha256=oUik-RLxFvZFZCtFztjkSfFYgJuRx4QzfwHVY1ny4Sc,26217
  ml_tools/_core/_ML_vision_transformers.py,sha256=imjL9h5kwpfuRn9rBelNpgtrdU-EecBEcHMFZMXTeZA,15303
@@ -92,7 +92,7 @@ ml_tools/_core/_ensemble_evaluation.py,sha256=17lWl4bWLT1BAMv_fhGf2D3wy-F4jx0Hgn
  ml_tools/_core/_ensemble_inference.py,sha256=PfZG-r65Vw3IAmBJZg9W0zYGEe-QbhfUh_rd2ho-rr8,8610
  ml_tools/_core/_ensemble_learning.py,sha256=X8ghbjDOLMENCWdISXLhDlHQtR3C6SW1tkTBAcfRRPY,22016
  ml_tools/_core/_excel_handler.py,sha256=gV4rSIsiowb0xllpEJxzUKaYDDVpmP_lxs9wZA76-cc,14050
- ml_tools/_core/_keys.py,sha256=
+ ml_tools/_core/_keys.py,sha256=4RE-ZuCJkUmqefz-dc3qrVbftqVAWkunZFrP2yJjpCU,6740
  ml_tools/_core/_logger.py,sha256=86Ge0sDE_WgwsZBglQRYPyFYX3lcsIo0NzszNPzlxuk,5254
  ml_tools/_core/_math_utilities.py,sha256=IlXAiZgTcLtus03jJOBOyF9ZCQDf8qLGjrCHu9Mrgak,9091
  ml_tools/_core/_models_advanced_base.py,sha256=ceW0V_CcfOnSFqHlxUhVU8-5mtQq4tFyo8TX-xVexrY,4982
@@ -104,7 +104,7 @@ ml_tools/_core/_schema.py,sha256=TM5WVVMoKOvr_Bc2z34sU_gzKlM465PRKTgdZaEOkGY,140
  ml_tools/_core/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
  ml_tools/_core/_serde.py,sha256=tsI4EO2Y7jrBMmbQ1pinDsPOrOg-SaPuB-Dt40q0taE,5609
  ml_tools/_core/_utilities.py,sha256=iA8fLWdhsIx4ut2Dp8M_OyU0Y3PPLgGdIklyl17x6xk,22560
- dragon_ml_toolbox-19.
- dragon_ml_toolbox-19.
- dragon_ml_toolbox-19.
- dragon_ml_toolbox-19.
+ dragon_ml_toolbox-19.7.0.dist-info/METADATA,sha256=64wv9eyG4FCm-QRo-9S0wSwARsYo6I-ULnRli6t6UxU,8764
+ dragon_ml_toolbox-19.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-19.7.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-19.7.0.dist-info/RECORD,,
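Each RECORD entry follows the wheel convention `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 SHA-256 of the file with the trailing padding stripped. A minimal sketch of how one of the 19.7.0 entries could be re-verified against an unpacked wheel (the local path is hypothetical):

```python
import base64
import hashlib
from pathlib import Path

def record_digest(path: Path) -> str:
    """Compute the RECORD-style digest: urlsafe base64 of SHA-256, no '=' padding."""
    raw = hashlib.sha256(path.read_bytes()).digest()
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")

# Digest taken from the RECORD entry for ml_tools/_core/_keys.py shown above.
expected = "4RE-ZuCJkUmqefz-dc3qrVbftqVAWkunZFrP2yJjpCU"
candidate = Path("ml_tools/_core/_keys.py")  # assumes the 19.7.0 wheel is unpacked here
if candidate.exists():
    print(record_digest(candidate) == expected)
```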
ml_tools/_core/_ML_configuration.py
CHANGED

@@ -65,7 +65,11 @@ class _BaseClassificationFormat:
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
  calibration_bins: int=15,
-
+ xtick_size: int=22,
+ ytick_size: int=22,
+ legend_size: int=26,
+ font_size: int=26,
+ cm_font_size: int=26) -> None:
  """
  Initializes the formatting configuration for single-label classification metrics.

@@ -84,6 +88,14 @@ class _BaseClassificationFormat:
  creating the calibration (reliability) plot.

  font_size (int): The base font size to apply to the plots.
+
+ xtick_size (int): Font size for x-axis tick labels.
+
+ ytick_size (int): Font size for y-axis tick labels.
+
+ legend_size (int): Font size for plot legends.
+
+ cm_font_size (int): Font size for the confusion matrix.

  <br>

@@ -97,13 +109,21 @@ class _BaseClassificationFormat:
  self.ROC_PR_line = ROC_PR_line
  self.calibration_bins = calibration_bins
  self.font_size = font_size
+ self.xtick_size = xtick_size
+ self.ytick_size = ytick_size
+ self.legend_size = legend_size
+ self.cm_font_size = cm_font_size

  def __repr__(self) -> str:
  parts = [
  f"cmap='{self.cmap}'",
  f"ROC_PR_line='{self.ROC_PR_line}'",
  f"calibration_bins={self.calibration_bins}",
- f"font_size={self.font_size}"
+ f"font_size={self.font_size}",
+ f"xtick_size={self.xtick_size}",
+ f"ytick_size={self.ytick_size}",
+ f"legend_size={self.legend_size}",
+ f"cm_font_size={self.cm_font_size}"
  ]
  return f"{self.__class__.__name__}({', '.join(parts)})"

@@ -115,7 +135,10 @@ class _BaseMultiLabelFormat:
  def __init__(self,
  cmap: str = "BuGn",
  ROC_PR_line: str='darkorange',
- font_size: int =
+ font_size: int = 25,
+ xtick_size: int=20,
+ ytick_size: int=20,
+ legend_size: int=23) -> None:
  """
  Initializes the formatting configuration for multi-label classification metrics.

@@ -132,6 +155,12 @@ class _BaseMultiLabelFormat:

  font_size (int): The base font size to apply to the plots.

+ xtick_size (int): Font size for x-axis tick labels.
+
+ ytick_size (int): Font size for y-axis tick labels.
+
+ legend_size (int): Font size for plot legends.
+
  <br>

  ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
@@ -143,12 +172,18 @@ class _BaseMultiLabelFormat:
  self.cmap = cmap
  self.ROC_PR_line = ROC_PR_line
  self.font_size = font_size
+ self.xtick_size = xtick_size
+ self.ytick_size = ytick_size
+ self.legend_size = legend_size

  def __repr__(self) -> str:
  parts = [
  f"cmap='{self.cmap}'",
  f"ROC_PR_line='{self.ROC_PR_line}'",
- f"font_size={self.font_size}"
+ f"font_size={self.font_size}",
+ f"xtick_size={self.xtick_size}",
+ f"ytick_size={self.ytick_size}",
+ f"legend_size={self.legend_size}"
  ]
  return f"{self.__class__.__name__}({', '.join(parts)})"

@@ -158,12 +193,14 @@ class _BaseRegressionFormat:
  [PRIVATE] Base configuration for regression metrics.
  """
  def __init__(self,
- font_size: int=
+ font_size: int=26,
  scatter_color: str='tab:blue',
  scatter_alpha: float=0.6,
  ideal_line_color: str='k',
  residual_line_color: str='red',
- hist_bins: Union[int, str] = 'auto'
+ hist_bins: Union[int, str] = 'auto',
+ xtick_size: int=22,
+ ytick_size: int=22) -> None:
  """
  Initializes the formatting configuration for regression metrics.

@@ -181,6 +218,8 @@ class _BaseRegressionFormat:
  hist_bins (int | str): The number of bins for the residuals histogram.
  Defaults to 'auto' to use seaborn's automatic bin selection.
  - Options: 'auto', 'sqrt', 10, 20
+ xtick_size (int): Font size for x-axis tick labels.
+ ytick_size (int): Font size for y-axis tick labels.

  <br>

@@ -192,6 +231,8 @@ class _BaseRegressionFormat:
  self.ideal_line_color = ideal_line_color
  self.residual_line_color = residual_line_color
  self.hist_bins = hist_bins
+ self.xtick_size = xtick_size
+ self.ytick_size = ytick_size

  def __repr__(self) -> str:
  parts = [
@@ -200,7 +241,9 @@ class _BaseRegressionFormat:
  f"scatter_alpha={self.scatter_alpha}",
  f"ideal_line_color='{self.ideal_line_color}'",
  f"residual_line_color='{self.residual_line_color}'",
- f"hist_bins='{self.hist_bins}'"
+ f"hist_bins='{self.hist_bins}'",
+ f"xtick_size={self.xtick_size}",
+ f"ytick_size={self.ytick_size}"
  ]
  return f"{self.__class__.__name__}({', '.join(parts)})"

@@ -248,7 +291,7 @@ class _BaseSequenceValueFormat:
  [PRIVATE] Base configuration for sequence to value metrics.
  """
  def __init__(self,
- font_size: int=
+ font_size: int=25,
  scatter_color: str='tab:blue',
  scatter_alpha: float=0.6,
  ideal_line_color: str='k',
@@ -300,8 +343,7 @@ class _BaseSequenceSequenceFormat:
  [PRIVATE] Base configuration for sequence-to-sequence metrics.
  """
  def __init__(self,
- font_size: int =
- plot_figsize: tuple[int, int] = (10, 6),
+ font_size: int = 25,
  grid_style: str = '--',
  rmse_color: str = 'tab:blue',
  rmse_marker: str = 'o-',
@@ -312,7 +354,6 @@ class _BaseSequenceSequenceFormat:

  Args:
  font_size (int): The base font size to apply to the plots.
- plot_figsize (Tuple[int, int]): Figure size for the plot.
  grid_style (str): Matplotlib linestyle for the plot grid.
  - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
  rmse_color (str): Matplotlib color for the RMSE line.
@@ -337,7 +378,6 @@ class _BaseSequenceSequenceFormat:
  ### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
  """
  self.font_size = font_size
- self.plot_figsize = plot_figsize
  self.grid_style = grid_style
  self.rmse_color = rmse_color
  self.rmse_marker = rmse_marker
@@ -347,7 +387,6 @@ class _BaseSequenceSequenceFormat:
  def __repr__(self) -> str:
  parts = [
  f"font_size={self.font_size}",
- f"plot_figsize={self.plot_figsize}",
  f"grid_style='{self.grid_style}'",
  f"rmse_color='{self.rmse_color}'",
  f"mae_color='{self.mae_color}'"
@@ -639,18 +678,22 @@ class RegressionMetricsFormat(_BaseRegressionFormat):
  Configuration for single-target regression.
  """
  def __init__(self,
- font_size: int=
+ font_size: int=26,
  scatter_color: str='tab:blue',
  scatter_alpha: float=0.6,
  ideal_line_color: str='k',
  residual_line_color: str='red',
- hist_bins: Union[int, str] = 'auto'
+ hist_bins: Union[int, str] = 'auto',
+ xtick_size: int=22,
+ ytick_size: int=22) -> None:
  super().__init__(font_size=font_size,
  scatter_color=scatter_color,
  scatter_alpha=scatter_alpha,
  ideal_line_color=ideal_line_color,
  residual_line_color=residual_line_color,
- hist_bins=hist_bins
+ hist_bins=hist_bins,
+ xtick_size=xtick_size,
+ ytick_size=ytick_size)


  # Multitarget regression
@@ -659,18 +702,22 @@ class MultiTargetRegressionMetricsFormat(_BaseRegressionFormat):
  Configuration for multi-target regression.
  """
  def __init__(self,
- font_size: int=
+ font_size: int=26,
  scatter_color: str='tab:blue',
  scatter_alpha: float=0.6,
  ideal_line_color: str='k',
  residual_line_color: str='red',
- hist_bins: Union[int, str] = 'auto'
+ hist_bins: Union[int, str] = 'auto',
+ xtick_size: int=22,
+ ytick_size: int=22) -> None:
  super().__init__(font_size=font_size,
  scatter_color=scatter_color,
  scatter_alpha=scatter_alpha,
  ideal_line_color=ideal_line_color,
  residual_line_color=residual_line_color,
- hist_bins=hist_bins
+ hist_bins=hist_bins,
+ xtick_size=xtick_size,
+ ytick_size=ytick_size)


  # Classification
@@ -682,11 +729,20 @@ class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
  calibration_bins: int=15,
- font_size: int=
+ font_size: int=26,
+ xtick_size: int=22,
+ ytick_size: int=22,
+ legend_size: int=26,
+ cm_font_size: int=26
+ ) -> None:
  super().__init__(cmap=cmap,
  ROC_PR_line=ROC_PR_line,
  calibration_bins=calibration_bins,
- font_size=font_size
+ font_size=font_size,
+ xtick_size=xtick_size,
+ ytick_size=ytick_size,
+ legend_size=legend_size,
+ cm_font_size=cm_font_size)


  class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
@@ -697,12 +753,20 @@ class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
  calibration_bins: int=15,
- font_size: int=
+ font_size: int=26,
+ xtick_size: int=22,
+ ytick_size: int=22,
+ legend_size: int=26,
+ cm_font_size: int=26
+ ) -> None:
  super().__init__(cmap=cmap,
  ROC_PR_line=ROC_PR_line,
  calibration_bins=calibration_bins,
- font_size=font_size
-
+ font_size=font_size,
+ xtick_size=xtick_size,
+ ytick_size=ytick_size,
+ legend_size=legend_size,
+ cm_font_size=cm_font_size)

  class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
  """
@@ -712,12 +776,20 @@ class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
  calibration_bins: int=15,
- font_size: int=
+ font_size: int=26,
+ xtick_size: int=22,
+ ytick_size: int=22,
+ legend_size: int=26,
+ cm_font_size: int=26
+ ) -> None:
  super().__init__(cmap=cmap,
  ROC_PR_line=ROC_PR_line,
  calibration_bins=calibration_bins,
- font_size=font_size
-
+ font_size=font_size,
+ xtick_size=xtick_size,
+ ytick_size=ytick_size,
+ legend_size=legend_size,
+ cm_font_size=cm_font_size)

  class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
  """
@@ -727,12 +799,20 @@ class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
  cmap: str="BuGn",
  ROC_PR_line: str='darkorange',
  calibration_bins: int=15,
- font_size: int=
+ font_size: int=26,
+ xtick_size: int=22,
+ ytick_size: int=22,
+ legend_size: int=26,
+ cm_font_size: int=26
+ ) -> None:
  super().__init__(cmap=cmap,
  ROC_PR_line=ROC_PR_line,
  calibration_bins=calibration_bins,
- font_size=font_size
-
+ font_size=font_size,
+ xtick_size=xtick_size,
+ ytick_size=ytick_size,
+ legend_size=legend_size,
+ cm_font_size=cm_font_size)

  # Multi-Label classification
  class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
@@ -742,11 +822,17 @@ class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
  def __init__(self,
  cmap: str = "BuGn",
  ROC_PR_line: str='darkorange',
- font_size: int =
+ font_size: int = 25,
+ xtick_size: int=20,
+ ytick_size: int=20,
+ legend_size: int=23
+ ) -> None:
  super().__init__(cmap=cmap,
  ROC_PR_line=ROC_PR_line,
- font_size=font_size
-
+ font_size=font_size,
+ xtick_size=xtick_size,
+ ytick_size=ytick_size,
+ legend_size=legend_size)

  # Segmentation
  class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
@@ -781,7 +867,7 @@ class SequenceValueMetricsFormat(_BaseSequenceValueFormat):
  Configuration for sequence-to-value prediction.
  """
  def __init__(self,
- font_size: int=
+ font_size: int=25,
  scatter_color: str='tab:blue',
  scatter_alpha: float=0.6,
  ideal_line_color: str='k',
@@ -800,15 +886,13 @@ class SequenceSequenceMetricsFormat(_BaseSequenceSequenceFormat):
  Configuration for sequence-to-sequence prediction.
  """
  def __init__(self,
- font_size: int =
- plot_figsize: tuple[int, int] = (10, 6),
+ font_size: int = 25,
  grid_style: str = '--',
  rmse_color: str = 'tab:blue',
  rmse_marker: str = 'o-',
  mae_color: str = 'tab:orange',
  mae_marker: str = 's--'):
  super().__init__(font_size=font_size,
- plot_figsize=plot_figsize,
  grid_style=grid_style,
  rmse_color=rmse_color,
  rmse_marker=rmse_marker,
ml_tools/_core/_ML_evaluation.py
CHANGED

@@ -48,6 +48,7 @@ __all__ = [

  DPI_value = _EvaluationConfig.DPI
  REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
+ CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE


  def plot_losses(history: dict, save_dir: Union[str, Path]):
@@ -67,7 +68,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
  _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
  return

- fig, ax = plt.subplots(figsize=
+ fig, ax = plt.subplots(figsize=_EvaluationConfig.LOSS_PLOT_SIZE, dpi=DPI_value)

  # --- Plot Losses (Left Y-axis) ---
  line_handles = [] # To store line objects for the legend
@@ -84,10 +85,11 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
  line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
  line_handles.append(line2)

- ax.set_title('Training and Validation Loss')
- ax.set_xlabel('Epochs')
- ax.set_ylabel('Loss', color='tab:blue')
- ax.tick_params(axis='y', labelcolor='tab:blue')
+ ax.set_title('Training and Validation Loss', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE + 2, pad=_EvaluationConfig.LABEL_PADDING)
+ ax.set_xlabel('Epochs', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
+ ax.set_ylabel('Loss', color='tab:blue', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
+ ax.tick_params(axis='y', labelcolor='tab:blue', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
+ ax.tick_params(axis='x', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
  ax.grid(True, linestyle='--')

  # --- Plot Learning Rate (Right Y-axis) ---
@@ -97,13 +99,17 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
  line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
  line_handles.append(line3)

- ax2.set_ylabel('Learning Rate', color='g')
- ax2.tick_params(axis='y', labelcolor='g')
+ ax2.set_ylabel('Learning Rate', color='g', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
+ ax2.tick_params(axis='y', labelcolor='g', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
  # Use scientific notation if the LR is very small
  ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+ # increase the size of the scientific notation
+ ax2.yaxis.get_offset_text().set_fontsize(_EvaluationConfig.LOSS_PLOT_TICK_SIZE - 2)
+ # remove grid from second y-axis
+ ax2.grid(False)

  # Combine legends from both axes
- ax.legend(handles=line_handles, loc='best')
+ ax.legend(handles=line_handles, loc='best', fontsize=_EvaluationConfig.LOSS_PLOT_LEGEND_SIZE)

  # ax.grid(True)
  plt.tight_layout()
@@ -142,10 +148,17 @@ def classification_metrics(save_dir: Union[str, Path],
  else:
  format_config = config

- original_rc_params = plt.rcParams.copy()
- plt.rcParams.update({'font.size': format_config.font_size})
+ # original_rc_params = plt.rcParams.copy()
+ # plt.rcParams.update({'font.size': format_config.font_size})

- #
+ # --- Set Font Sizes ---
+ xtick_size = format_config.xtick_size
+ ytick_size = format_config.ytick_size
+ legend_size = format_config.legend_size
+
+ # config font size for heatmap
+ cm_font_size = format_config.cm_font_size
+ cm_tick_size = cm_font_size - 4

  # --- Parse class_map ---
  map_labels = None
@@ -176,61 +189,122 @@ def classification_metrics(save_dir: Union[str, Path],
  try:
  # Create DataFrame from report
  report_df = pd.DataFrame(report_dict)
-
- # 1. Drop the 'accuracy' column (single float)
- if 'accuracy' in report_df.columns:
- report_df = report_df.drop(columns=['accuracy'])
-
- # 2. Select all metric rows *except* the last one ('support')
- # 3. Transpose the DataFrame
- plot_df = report_df.iloc[:-1, :].T
-
- fig_height = max(5.0, len(plot_df.index) * 0.5 + 2.0)
- plt.figure(figsize=(7, fig_height), dpi=DPI_value)

-
+ # 1. Robust Cleanup: Drop by name, not position
+ # Remove 'accuracy' column if it exists (handles the scalar value issue)
+ report_df = report_df.drop(columns=['accuracy'], errors='ignore')
+
+ # Remove 'support' row explicitly (safer than iloc[:-1])
+ if 'support' in report_df.index:
+ report_df = report_df.drop(index='support')
+
+ # 2. Transpose: Rows = Classes, Cols = Metrics
+ plot_df = report_df.T
+
+ # 3. Dynamic Height Calculation
+ # (Base height of 4 + 0.5 inches per class row)
+ fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
+ fig_width = 8.0 # Set a fixed width
+
+ # --- Use calculated dimensions, not the config constant ---
+ fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
+
+ # sns.set_theme(font_scale=1.4)
  sns.heatmap(plot_df,
  annot=True,
  cmap=format_config.cmap,
  fmt='.2f',
  vmin=0.0,
- vmax=1.0
-
-
+ vmax=1.0,
+ cbar_kws={'shrink': 0.9}) # Shrink colorbar slightly to fit better
+
+ # sns.set_theme(font_scale=1.0)
+
+ ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+
+ # manually increase the font size of the elements
+ for text in ax_heat.texts:
+ text.set_fontsize(cm_tick_size)
+
+ # manually increase the size of the colorbar ticks
+ cbar = ax_heat.collections[0].colorbar
+ cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore
+
+ # Update Ticks
+ ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
+ ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0) # Ensure Y labels are horizontal
+
  plt.tight_layout()
+
  heatmap_path = save_dir_path / "classification_report_heatmap.svg"
  plt.savefig(heatmap_path)
  _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
- plt.close()
+ plt.close(fig_heat)
+
  except Exception as e:
  _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-
+
  # --- labels for Confusion Matrix ---
  plot_labels = map_labels
  plot_display_labels = map_display_labels

- #
-
+ # 1. DYNAMIC SIZE CALCULATION
+ # Calculate figure size based on number of classes.
+ n_classes = len(plot_labels) if plot_labels is not None else len(np.unique(y_true))
+ # Ensure a minimum size so very small matrices aren't tiny
+ fig_w = max(9, n_classes * 0.8 + 3)
+ fig_h = max(8, n_classes * 0.8 + 2)
+
+ # Use the calculated size instead of CLASSIFICATION_PLOT_SIZE
+ fig_cm, ax_cm = plt.subplots(figsize=(fig_w, fig_h), dpi=DPI_value)
  disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
  y_pred,
  cmap=format_config.cmap,
  ax=ax_cm,
  normalize='true',
  labels=plot_labels,
- display_labels=plot_display_labels
+ display_labels=plot_display_labels,
+ colorbar=False)

  disp_.im_.set_clim(vmin=0.0, vmax=1.0)

  # Turn off gridlines
  ax_cm.grid(False)

- #
+ # 2. CHECK FOR FONT CLASH
+ # If matrix is huge, force text smaller. If small, allow user config.
+ final_font_size = cm_font_size + 2
+ if n_classes > 2:
+ final_font_size = cm_font_size - n_classes # Decrease font size for larger matrices
+
  for text in ax_cm.texts:
- text.set_fontsize(
+ text.set_fontsize(final_font_size)
+
+ # Update Ticks for Confusion Matrix
+ ax_cm.tick_params(axis='x', labelsize=cm_tick_size)
+ ax_cm.tick_params(axis='y', labelsize=cm_tick_size)
+
+ #if more than 3 classes, rotate x ticks
+ if n_classes > 3:
+ plt.setp(ax_cm.get_xticklabels(), rotation=45, ha='right', rotation_mode="anchor")

+ # Set titles and labels with padding
+ ax_cm.set_title("Confusion Matrix", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size + 2)
+ ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+ ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+
+ # --- ADJUST COLORBAR FONT & SIZE---
+ # Manually add the colorbar with the 'shrink' parameter
+ cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
+
+ # Update the tick size on the new cbar object
+ cbar.ax.tick_params(labelsize=cm_tick_size)
+
+ # (Optional) add a label to the bar itself (e.g. "Probability")
+ # cbar.set_label('Probability', fontsize=12)
+
  fig_cm.tight_layout()

- ax_cm.set_title("Confusion Matrix")
  cm_path = save_dir_path / "confusion_matrix.svg"
  plt.savefig(cm_path)
  _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
@@ -335,34 +409,50 @@ def classification_metrics(save_dir: Union[str, Path],
  # Calculate AUC.
  auc = roc_auc_score(y_true_binary, y_score)

- fig_roc, ax_roc = plt.subplots(figsize=
+ fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
  ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
  ax_roc.plot([0, 1], [0, 1], 'k--')
- ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
- ax_roc.set_xlabel('False Positive Rate')
- ax_roc.set_ylabel('True Positive Rate')
-
+ ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
+ ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+ ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+
+ # Apply Ticks and Legend sizing
+ ax_roc.tick_params(axis='x', labelsize=xtick_size)
+ ax_roc.tick_params(axis='y', labelsize=ytick_size)
+ ax_roc.legend(loc='lower right', fontsize=legend_size)
+
  ax_roc.grid(True)
  roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+
+ plt.tight_layout()
+
  plt.savefig(roc_path)
  plt.close(fig_roc)

  # --- Save Precision-Recall Curve ---
  precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
  ap_score = average_precision_score(y_true_binary, y_score)
- fig_pr, ax_pr = plt.subplots(figsize=
+ fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
  ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
- ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
- ax_pr.set_xlabel('Recall')
- ax_pr.set_ylabel('Precision')
-
+ ax_pr.set_title(f'Precision-Recall Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
+ ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+ ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+
+ # Apply Ticks and Legend sizing
+ ax_pr.tick_params(axis='x', labelsize=xtick_size)
+ ax_pr.tick_params(axis='y', labelsize=ytick_size)
+ ax_pr.legend(loc='lower left', fontsize=legend_size)
+
  ax_pr.grid(True)
  pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+
+ plt.tight_layout()
+
  plt.savefig(pr_path)
  plt.close(fig_pr)

  # --- Save Calibration Plot ---
- fig_cal, ax_cal = plt.subplots(figsize=
+ fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)

  # --- Step 1: Get binned data *without* plotting ---
  with plt.ioff(): # Suppress showing the temporary plot
@@ -386,7 +476,7 @@ def classification_metrics(save_dir: Union[str, Path],
  y=line_y,
  ax=ax_cal,
  scatter=False,
- label=f"
+ label=f"Model calibration",
  line_kws={
  'color': format_config.ROC_PR_line,
  'linestyle': '--',
@@ -394,15 +484,19 @@ def classification_metrics(save_dir: Union[str, Path],
  }
  )

- ax_cal.set_title(f'Reliability Curve{plot_title}')
- ax_cal.set_xlabel('Mean Predicted Probability')
- ax_cal.set_ylabel('Fraction of Positives')
+ ax_cal.set_title(f'Reliability Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
+ ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+ ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)

  # --- Step 3: Set final limits *after* plotting ---
  ax_cal.set_ylim(0.0, 1.0)
  ax_cal.set_xlim(0.0, 1.0)

-
+ # Apply Ticks and Legend sizing
+ ax_cal.tick_params(axis='x', labelsize=xtick_size)
+ ax_cal.tick_params(axis='y', labelsize=ytick_size)
+ ax_cal.legend(loc='lower right', fontsize=legend_size)
+
  ax_cal.grid(True)
  plt.tight_layout()

@@ -413,7 +507,7 @@ def classification_metrics(save_dir: Union[str, Path],
  _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")

  # restore RC params
- plt.rcParams.update(original_rc_params)
+ # plt.rcParams.update(original_rc_params)


  def regression_metrics(
@@ -440,8 +534,13 @@ def regression_metrics(
  format_config = config

  # --- Set Matplotlib font size ---
- original_rc_params = plt.rcParams.copy()
- plt.rcParams.update({'font.size': format_config.font_size})
+ # original_rc_params = plt.rcParams.copy()
+ # plt.rcParams.update({'font.size': format_config.font_size})
+
+ # --- Resolve Font Sizes ---
+ xtick_size = format_config.xtick_size
+ ytick_size = format_config.ytick_size
+ base_font_size = format_config.font_size

  # --- Calculate Metrics ---
  rmse = np.sqrt(mean_squared_error(y_true, y_pred))
@@ -472,9 +571,14 @@ def regression_metrics(
  alpha=format_config.scatter_alpha,
  color=format_config.scatter_color)
  ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
- ax_res.set_xlabel("Predicted Values")
- ax_res.set_ylabel("Residuals")
- ax_res.set_title("Residual Plot")
+ ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_res.set_title("Residual Plot", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+
+ # Apply Ticks
+ ax_res.tick_params(axis='x', labelsize=xtick_size)
+ ax_res.tick_params(axis='y', labelsize=ytick_size)
+
  ax_res.grid(True)
  plt.tight_layout()
  res_path = save_dir_path / "residual_plot.svg"
@@ -491,9 +595,14 @@ def regression_metrics(
  linestyle='--',
  lw=2,
  color=format_config.ideal_line_color)
- ax_tvp.set_xlabel('True Values')
- ax_tvp.set_ylabel('Predictions')
- ax_tvp.set_title('True vs. Predicted Values')
+ ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_tvp.set_ylabel('Predictions', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_tvp.set_title('True vs. Predicted Values', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+
+ # Apply Ticks
+ ax_tvp.tick_params(axis='x', labelsize=xtick_size)
+ ax_tvp.tick_params(axis='y', labelsize=ytick_size)
+
  ax_tvp.grid(True)
  plt.tight_layout()
  tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
@@ -506,9 +615,14 @@ def regression_metrics(
  sns.histplot(residuals, kde=True, ax=ax_hist,
  bins=format_config.hist_bins,
  color=format_config.scatter_color)
- ax_hist.set_xlabel("Residual Value")
- ax_hist.set_ylabel("Frequency")
- ax_hist.set_title("Distribution of Residuals")
+ ax_hist.set_xlabel("Residual Value", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_hist.set_ylabel("Frequency", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_hist.set_title("Distribution of Residuals", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+
+ # Apply Ticks
+ ax_hist.tick_params(axis='x', labelsize=xtick_size)
+ ax_hist.tick_params(axis='y', labelsize=ytick_size)
+
  ax_hist.grid(True)
  plt.tight_layout()
  hist_path = save_dir_path / "residuals_histogram.svg"
@@ -517,7 +631,7 @@ def regression_metrics(
  plt.close(fig_hist)

  # --- Restore RC params ---
- plt.rcParams.update(original_rc_params)
+ # plt.rcParams.update(original_rc_params)


  def shap_summary_plot(model,
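The recurring change in `_ML_evaluation.py` is that the global `plt.rcParams` font override is commented out and replaced by explicit per-axes sizing (title, label, tick, and legend fonts set on each Axes). A self-contained matplotlib sketch of that pattern with toy data and illustrative sizes (the real code pulls them from the format config and `_EvaluationConfig`):

```python
import matplotlib.pyplot as plt
import numpy as np

# Illustrative sizes; not the package's constants.
FONT, XTICK, YTICK, LEGEND, PAD = 26, 22, 22, 26, 10

fpr = np.linspace(0.0, 1.0, 50)
tpr = np.sqrt(fpr)  # toy curve standing in for a real ROC

fig, ax = plt.subplots(figsize=(10, 10), dpi=100)
ax.plot(fpr, tpr, label="AUC = 0.75", color="darkorange")
ax.plot([0, 1], [0, 1], "k--")

# Size each element on the Axes instead of mutating plt.rcParams globally.
ax.set_title("Receiver Operating Characteristic", pad=PAD, fontsize=FONT + 2)
ax.set_xlabel("False Positive Rate", labelpad=PAD, fontsize=FONT)
ax.set_ylabel("True Positive Rate", labelpad=PAD, fontsize=FONT)
ax.tick_params(axis="x", labelsize=XTICK)
ax.tick_params(axis="y", labelsize=YTICK)
ax.legend(loc="lower right", fontsize=LEGEND)
ax.grid(True)

plt.tight_layout()
plt.savefig("roc_curve.svg")
plt.close(fig)
```

Keeping the sizing local to each figure avoids leaking an enlarged default font into plots the caller creates afterwards, which is presumably why the rcParams save/restore dance was retired here.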
ml_tools/_core/_ML_evaluation_multi.py
CHANGED

@@ -44,6 +44,7 @@ __all__ = [

  DPI_value = _EvaluationConfig.DPI
  REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
+ CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE


  def multi_target_regression_metrics(
@@ -88,8 +89,13 @@ def multi_target_regression_metrics(
  format_config = config

  # --- Set Matplotlib font size ---
- original_rc_params = plt.rcParams.copy()
- plt.rcParams.update({'font.size': format_config.font_size})
+ # original_rc_params = plt.rcParams.copy()
+ # plt.rcParams.update({'font.size': format_config.font_size})
+
+ # ticks font sizes
+ xtick_size = format_config.xtick_size
+ ytick_size = format_config.ytick_size
+ base_font_size = format_config.font_size

  _LOGGER.debug("--- Multi-Target Regression Evaluation ---")

@@ -105,11 +111,11 @@ def multi_target_regression_metrics(
  r2 = r2_score(true_i, pred_i)
  medae = median_absolute_error(true_i, pred_i)
  metrics_summary.append({
- '
- '
- '
- '
- '
+ 'Target': name,
+ 'RMSE': rmse,
+ 'MAE': mae,
+ 'MedAE': medae,
+ 'R2-score': r2,
  })

  # --- Save Residual Plot ---
@@ -121,9 +127,14 @@ def multi_target_regression_metrics(
  s=50,
  color=format_config.scatter_color) # Use config color
  ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--') # Use config color
- ax_res.set_xlabel("Predicted Values")
- ax_res.set_ylabel("Residuals")
- ax_res.set_title(f"Residual Plot for '{name}'")
+ ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_res.set_title(f"Residual Plot for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+
+ # Apply Ticks
+ ax_res.tick_params(axis='x', labelsize=xtick_size)
+ ax_res.tick_params(axis='y', labelsize=ytick_size)
+
  ax_res.grid(True, linestyle='--', alpha=0.6)
  plt.tight_layout()
  res_path = save_dir_path / f"residual_plot_{sanitized_name}.svg"
@@ -141,9 +152,14 @@ def multi_target_regression_metrics(
  linestyle='--',
  lw=2,
  color=format_config.ideal_line_color) # Use config color
- ax_tvp.set_xlabel('True Values')
- ax_tvp.set_ylabel('Predicted Values')
- ax_tvp.set_title(f"True vs. Predicted for '{name}'")
+ ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_tvp.set_ylabel('Predicted Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_tvp.set_title(f"True vs. Predicted for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+
+ # Apply Ticks
+ ax_tvp.tick_params(axis='x', labelsize=xtick_size)
+ ax_tvp.tick_params(axis='y', labelsize=ytick_size)
+
  ax_tvp.grid(True, linestyle='--', alpha=0.6)
  plt.tight_layout()
  tvp_path = save_dir_path / f"true_vs_predicted_plot_{sanitized_name}.svg"
@@ -157,7 +173,7 @@ def multi_target_regression_metrics(
  _LOGGER.info(f"Full regression report saved to '{report_path.name}'")

  # --- Restore RC params ---
- plt.rcParams.update(original_rc_params)
+ # plt.rcParams.update(original_rc_params)


  def multi_label_classification_metrics(
@@ -205,10 +221,14 @@ def multi_label_classification_metrics(
  # y_pred is now passed in directly, no threshold needed.

  # --- Save current RC params and update font size ---
- original_rc_params = plt.rcParams.copy()
- plt.rcParams.update({'font.size': format_config.font_size})
+ # original_rc_params = plt.rcParams.copy()
+ # plt.rcParams.update({'font.size': format_config.font_size})

- #
+ # ticks and legend font sizes
+ xtick_size = format_config.xtick_size
+ ytick_size = format_config.ytick_size
+ legend_size = format_config.legend_size
+ base_font_size = format_config.font_size

  # --- Calculate and Save Overall Metrics (using y_pred) ---
  h_loss = hamming_loss(y_true, y_pred)
@@ -224,7 +244,7 @@ def multi_label_classification_metrics(
  f"--------------------------------------------------\n"
  )
  # print(overall_report)
- overall_report_path = save_dir_path / "
+ overall_report_path = save_dir_path / "classification_report.txt"
  overall_report_path.write_text(overall_report)

  # --- Per-Label Metrics and Plots ---
@@ -241,14 +261,15 @@ def multi_label_classification_metrics(
  report_path.write_text(report_text) # type: ignore

  # --- Save Confusion Matrix (uses y_pred) ---
- fig_cm, ax_cm = plt.subplots(figsize=
+ fig_cm, ax_cm = plt.subplots(figsize=_EvaluationConfig.CM_SIZE, dpi=_EvaluationConfig.DPI)
  disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
  pred_i,
  cmap=format_config.cmap, # Use config cmap
  ax=ax_cm,
  normalize='true',
  labels=[0, 1],
- display_labels=["Negative", "Positive"]
+ display_labels=["Negative", "Positive"],
+ colorbar=False)

  disp_.im_.set_clim(vmin=0.0, vmax=1.0)

@@ -257,11 +278,26 @@ def multi_label_classification_metrics(

  # Manually update font size of cell texts
  for text in ax_cm.texts:
- text.set_fontsize(
+ text.set_fontsize(base_font_size + 2) # Use config font_size
+
+ # Apply ticks
+ ax_cm.tick_params(axis='x', labelsize=xtick_size)
+ ax_cm.tick_params(axis='y', labelsize=ytick_size)
+
+ # Set titles and labels with padding
+ ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+ ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+
+ # --- ADJUST COLORBAR FONT & SIZE---
+ # Manually add the colorbar with the 'shrink' parameter
+ cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)

-
+ # Update the tick size on the new cbar object
+ cbar.ax.tick_params(labelsize=ytick_size) # type: ignore
+
+ plt.tight_layout()

- ax_cm.set_title(f"Confusion Matrix for '{name}'")
  cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
  plt.savefig(cm_path)
  plt.close(fig_cm)
@@ -302,12 +338,23 @@ def multi_label_classification_metrics(
  _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")

  auc = roc_auc_score(true_i, prob_i)
- fig_roc, ax_roc = plt.subplots(figsize=
+ fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
  ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
  ax_roc.plot([0, 1], [0, 1], 'k--')
-
- ax_roc.
- ax_roc.
+
+ ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+ ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+
+ # Apply ticks and legend font size
+ ax_roc.tick_params(axis='x', labelsize=xtick_size)
+ ax_roc.tick_params(axis='y', labelsize=ytick_size)
+ ax_roc.legend(loc='lower right', fontsize=legend_size)
+
+ ax_roc.grid(True, linestyle='--', alpha=0.6)
+
+ plt.tight_layout()
+
  roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
  plt.savefig(roc_path)
  plt.close(fig_roc)
@@ -315,17 +362,27 @@ def multi_label_classification_metrics(
  # --- Save Precision-Recall Curve (uses y_prob) ---
  precision, recall, _ = precision_recall_curve(true_i, prob_i)
  ap_score = average_precision_score(true_i, prob_i)
- fig_pr, ax_pr = plt.subplots(figsize=
+ fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
  ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
- ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
- ax_pr.set_xlabel('Recall'
- ax_pr.
+ ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+ ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+ ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+
+ # Apply ticks and legend font size
+ ax_pr.tick_params(axis='x', labelsize=xtick_size)
+ ax_pr.tick_params(axis='y', labelsize=ytick_size)
+ ax_pr.legend(loc='lower left', fontsize=legend_size)
+
+ ax_pr.grid(True, linestyle='--', alpha=0.6)
+
+ fig_pr.tight_layout()
+
  pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
  plt.savefig(pr_path)
  plt.close(fig_pr)

  # restore RC params
- plt.rcParams.update(original_rc_params)
+ # plt.rcParams.update(original_rc_params)

  _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")

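The confusion-matrix hunks disable the automatic colorbar from `ConfusionMatrixDisplay.from_predictions` and re-add it manually so its height and tick labels can be controlled. A small standalone sketch of that scikit-learn/matplotlib pattern with toy labels and illustrative sizes:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

y_true = [0, 1, 1, 0, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1, 1, 1]

fig, ax = plt.subplots(figsize=(9, 8), dpi=100)
disp = ConfusionMatrixDisplay.from_predictions(
    y_true,
    y_pred,
    normalize="true",
    labels=[0, 1],
    display_labels=["Negative", "Positive"],
    ax=ax,
    colorbar=False,  # suppress the default colorbar...
)
disp.im_.set_clim(vmin=0.0, vmax=1.0)

# ...then add it back shrunken, with an explicit tick label size.
cbar = fig.colorbar(disp.im_, ax=ax, shrink=0.8)
cbar.ax.tick_params(labelsize=18)

for text in ax.texts:  # enlarge the in-cell annotations
    text.set_fontsize(24)
ax.tick_params(axis="both", labelsize=20)
ax.set_title("Confusion Matrix for 'label_0'", pad=10, fontsize=26)

plt.tight_layout()
plt.close(fig)
```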
ml_tools/_core/_ML_sequence_evaluation.py
CHANGED

@@ -186,7 +186,7 @@ def sequence_to_sequence_metrics(
  _LOGGER.info(f"📝 Seq-to-Seq per-step report saved as '{report_path.name}'")

  # --- Create and save plot ---
- fig, ax1 = plt.subplots(figsize=
+ fig, ax1 = plt.subplots(figsize=SEQUENCE_PLOT_SIZE, dpi=DPI_value)

  # Plot RMSE
  color_rmse = format_config.rmse_color
ml_tools/_core/_ML_vision_evaluation.py
CHANGED

@@ -165,7 +165,7 @@ def segmentation_metrics(
  cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)

  # Plot
- fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=
+ fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=DPI_value)
  disp = ConfusionMatrixDisplay(
  confusion_matrix=cm,
  display_labels=display_names
ml_tools/_core/_keys.py
CHANGED

@@ -223,10 +223,19 @@ class SchemaKeys:

  class _EvaluationConfig:
  """Set config values for evaluation modules."""
- DPI =
-
-
-
+ DPI = 400
+ LABEL_PADDING = 10
+ # large sizes for SVG layout to accommodate large fonts
+ REGRESSION_PLOT_SIZE = (12, 8)
+ SEQUENCE_PLOT_SIZE = (12, 8)
+ CLASSIFICATION_PLOT_SIZE = (10, 10)
+ # Loss plot
+ LOSS_PLOT_SIZE = (18, 9)
+ LOSS_PLOT_LABEL_SIZE = 24
+ LOSS_PLOT_TICK_SIZE = 22
+ LOSS_PLOT_LEGEND_SIZE = 24
+ # CM settings
+ CM_SIZE = (9, 8) # used for multi label binary classification confusion matrix

  class _OneHotOtherPlaceholder:
  """Used internally by GUI_tools."""
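These constants centralize figure geometry and loss-plot font sizes for all evaluation modules. A minimal sketch of how they are consumed, using a local copy of the 19.7.0 values so the snippet runs on its own:

```python
import matplotlib.pyplot as plt

class _EvaluationConfig:
    """Local copy of the 19.7.0 constants, for illustration only."""
    DPI = 400
    LABEL_PADDING = 10
    REGRESSION_PLOT_SIZE = (12, 8)
    SEQUENCE_PLOT_SIZE = (12, 8)
    CLASSIFICATION_PLOT_SIZE = (10, 10)
    LOSS_PLOT_SIZE = (18, 9)
    LOSS_PLOT_LABEL_SIZE = 24
    LOSS_PLOT_TICK_SIZE = 22
    LOSS_PLOT_LEGEND_SIZE = 24
    CM_SIZE = (9, 8)

# Figures in the evaluation modules are created from these constants, e.g. the loss plot:
fig, ax = plt.subplots(figsize=_EvaluationConfig.LOSS_PLOT_SIZE, dpi=_EvaluationConfig.DPI)
ax.set_xlabel("Epochs",
              fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE,
              labelpad=_EvaluationConfig.LABEL_PADDING)
ax.tick_params(axis="both", labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
plt.close(fig)
```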
Files without changes: {dragon_ml_toolbox-19.6.0.dist-info → dragon_ml_toolbox-19.7.0.dist-info}/WHEEL, licenses/LICENSE, licenses/LICENSE-THIRD-PARTY.md, and top_level.txt.