dragon-ml-toolbox 16.2.0__py3-none-any.whl → 16.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-16.2.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-16.2.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/RECORD +8 -8
- ml_tools/ML_configuration.py +44 -30
- ml_tools/ML_evaluation.py +16 -3
- {dragon_ml_toolbox-16.2.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-16.2.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-16.2.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-16.2.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
dragon_ml_toolbox-16.2.
|
|
2
|
-
dragon_ml_toolbox-16.2.
|
|
1
|
+
dragon_ml_toolbox-16.2.1.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
|
|
2
|
+
dragon_ml_toolbox-16.2.1.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=gkOdNDbKYpIJezwSo2CEnISkLeYfYHv9t8b5K2-P69A,2687
|
|
3
3
|
ml_tools/ETL_cleaning.py,sha256=Bg0nTmpNzQKDdezK3m0NjYT7N8_ANGlmD9mDXjggqkA,20522
|
|
4
4
|
ml_tools/ETL_engineering.py,sha256=PGXvlvMWa05J1rsMNXxnHzXIe2K68qhtigSn74W8kFI,54961
|
|
5
5
|
ml_tools/GUI_tools.py,sha256=QMSu-8eSNminD6A6Yg9sXo4ff6GNPThwRBVgQQwAAbY,45508
|
|
6
6
|
ml_tools/MICE_imputation.py,sha256=2MsHeKTd8MSBIYmj0q671Fm4wCBvMGjpxULp__jDNgo,20812
|
|
7
7
|
ml_tools/ML_callbacks.py,sha256=EF7Px_IV3IIJpfaT0Nwbv4-_0C6IUlJ_xjzHOekXwq0,16410
|
|
8
|
-
ml_tools/ML_configuration.py,sha256=
|
|
8
|
+
ml_tools/ML_configuration.py,sha256=W4KY4SrpIQAKCmLfVntTWW8fsEuVpHz-CXf_rnNNGqM,31905
|
|
9
9
|
ml_tools/ML_datasetmaster.py,sha256=isvRXI8vNRTFNCFFFpGtsUA8hS6ZDNezLuDpKd9VU9c,28514
|
|
10
|
-
ml_tools/ML_evaluation.py,sha256=
|
|
10
|
+
ml_tools/ML_evaluation.py,sha256=eFrOCmETRr1FnbwPk6fbflNXEBLqnnLBWjAI5LmF3dg,30576
|
|
11
11
|
ml_tools/ML_evaluation_multi.py,sha256=mEN8jKaU1N7UdgldEykqME0MV_yubojD1StyQC5bFEA,20416
|
|
12
12
|
ml_tools/ML_inference.py,sha256=qxoeurcqp-soapfgHUuzt-NFg0KGwg_wOIuzsRMyJqQ,29447
|
|
13
13
|
ml_tools/ML_models.py,sha256=OEiuUduu2KqsfXZIfzJHR3uop_Zo6dzdKtvaOeRt1G0,27932
|
|
@@ -45,7 +45,7 @@ ml_tools/optimization_tools.py,sha256=_sCLZy9LRIIqt1zkYyKNsSbDK3JjRIhC-sADq-Jteg
|
|
|
45
45
|
ml_tools/path_manager.py,sha256=2lTnhfDNdYlrqP_LGDoP51LdUf9hlTsZKuZJoYq5W-U,18462
|
|
46
46
|
ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
|
|
47
47
|
ml_tools/utilities.py,sha256=wFwdv7xFV8Sv6kNy4_tE7RNasRs_318Zm7s65Uwu2Us,22509
|
|
48
|
-
dragon_ml_toolbox-16.2.
|
|
49
|
-
dragon_ml_toolbox-16.2.
|
|
50
|
-
dragon_ml_toolbox-16.2.
|
|
51
|
-
dragon_ml_toolbox-16.2.
|
|
48
|
+
dragon_ml_toolbox-16.2.1.dist-info/METADATA,sha256=Rr23JuJbUJyYhM-GugIGAjSEde7hj7HL3mCppKU1zCA,6591
|
|
49
|
+
dragon_ml_toolbox-16.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
50
|
+
dragon_ml_toolbox-16.2.1.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
|
|
51
|
+
dragon_ml_toolbox-16.2.1.dist-info/RECORD,,
|
ml_tools/ML_configuration.py
CHANGED
|
@@ -39,7 +39,7 @@ class _BaseClassificationFormat:
|
|
|
39
39
|
[PRIVATE] Base configuration for single-label classification metrics.
|
|
40
40
|
"""
|
|
41
41
|
def __init__(self,
|
|
42
|
-
cmap: str="
|
|
42
|
+
cmap: str="BuGn",
|
|
43
43
|
ROC_PR_line: str='darkorange',
|
|
44
44
|
calibration_bins: int=15,
|
|
45
45
|
font_size: int=16) -> None:
|
|
@@ -64,7 +64,11 @@ class _BaseClassificationFormat:
|
|
|
64
64
|
|
|
65
65
|
<br>
|
|
66
66
|
|
|
67
|
-
|
|
67
|
+
### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
|
|
68
|
+
|
|
69
|
+
<br>
|
|
70
|
+
|
|
71
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
68
72
|
"""
|
|
69
73
|
self.cmap = cmap
|
|
70
74
|
self.ROC_PR_line = ROC_PR_line
|
|
@@ -86,29 +90,33 @@ class _BaseMultiLabelFormat:
|
|
|
86
90
|
[PRIVATE] Base configuration for multi-label binary classification metrics.
|
|
87
91
|
"""
|
|
88
92
|
def __init__(self,
|
|
93
|
+
cmap: str = "BuGn",
|
|
89
94
|
ROC_PR_line: str='darkorange',
|
|
90
|
-
cmap: str = "Blues",
|
|
91
95
|
font_size: int = 16) -> None:
|
|
92
96
|
"""
|
|
93
97
|
Initializes the formatting configuration for multi-label classification metrics.
|
|
94
98
|
|
|
95
99
|
Args:
|
|
100
|
+
cmap (str): The matplotlib colormap name for the per-label
|
|
101
|
+
confusion matrices. Defaults to "Blues".
|
|
102
|
+
- Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
|
|
103
|
+
- Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
|
|
104
|
+
|
|
96
105
|
ROC_PR_line (str): The color name or hex code for the line plotted
|
|
97
106
|
on the ROC and Precision-Recall curves (one for each label).
|
|
98
107
|
Defaults to 'darkorange'.
|
|
99
108
|
- Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
|
|
100
109
|
- Hex codes: '#FF6347', '#4682B4'
|
|
101
110
|
|
|
102
|
-
cmap (str): The matplotlib colormap name for the per-label
|
|
103
|
-
confusion matrices. Defaults to "Blues".
|
|
104
|
-
- Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
|
|
105
|
-
- Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
|
|
106
|
-
|
|
107
111
|
font_size (int): The base font size to apply to the plots. Defaults to 16.
|
|
108
112
|
|
|
109
113
|
<br>
|
|
110
114
|
|
|
111
|
-
|
|
115
|
+
### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
|
|
116
|
+
|
|
117
|
+
<br>
|
|
118
|
+
|
|
119
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
112
120
|
"""
|
|
113
121
|
self.cmap = cmap
|
|
114
122
|
self.ROC_PR_line = ROC_PR_line
|
|
@@ -116,8 +124,8 @@ class _BaseMultiLabelFormat:
|
|
|
116
124
|
|
|
117
125
|
def __repr__(self) -> str:
|
|
118
126
|
parts = [
|
|
119
|
-
f"ROC_PR_line='{self.ROC_PR_line}'",
|
|
120
127
|
f"cmap='{self.cmap}'",
|
|
128
|
+
f"ROC_PR_line='{self.ROC_PR_line}'",
|
|
121
129
|
f"font_size={self.font_size}"
|
|
122
130
|
]
|
|
123
131
|
return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
@@ -154,7 +162,7 @@ class _BaseRegressionFormat:
|
|
|
154
162
|
|
|
155
163
|
<br>
|
|
156
164
|
|
|
157
|
-
|
|
165
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
158
166
|
"""
|
|
159
167
|
self.font_size = font_size
|
|
160
168
|
self.scatter_color = scatter_color
|
|
@@ -180,8 +188,8 @@ class _BaseSegmentationFormat:
|
|
|
180
188
|
[PRIVATE] Base configuration for segmentation metrics.
|
|
181
189
|
"""
|
|
182
190
|
def __init__(self,
|
|
183
|
-
heatmap_cmap: str =
|
|
184
|
-
cm_cmap: str = "
|
|
191
|
+
heatmap_cmap: str = "BuGn",
|
|
192
|
+
cm_cmap: str = "Purples",
|
|
185
193
|
font_size: int = 16) -> None:
|
|
186
194
|
"""
|
|
187
195
|
Initializes the formatting configuration for segmentation metrics.
|
|
@@ -198,7 +206,7 @@ class _BaseSegmentationFormat:
|
|
|
198
206
|
|
|
199
207
|
<br>
|
|
200
208
|
|
|
201
|
-
|
|
209
|
+
### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
|
|
202
210
|
"""
|
|
203
211
|
self.heatmap_cmap = heatmap_cmap
|
|
204
212
|
self.cm_cmap = cm_cmap
|
|
@@ -241,10 +249,10 @@ class _BaseSequenceValueFormat:
|
|
|
241
249
|
hist_bins (int | str): The number of bins for the residuals histogram.
|
|
242
250
|
Defaults to 'auto' to use seaborn's automatic bin selection.
|
|
243
251
|
- Options: 'auto', 'sqrt', 10, 20
|
|
244
|
-
|
|
252
|
+
|
|
245
253
|
<br>
|
|
246
254
|
|
|
247
|
-
|
|
255
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
248
256
|
"""
|
|
249
257
|
self.font_size = font_size
|
|
250
258
|
self.scatter_color = scatter_color
|
|
@@ -296,9 +304,15 @@ class _BaseSequenceSequenceFormat:
|
|
|
296
304
|
|
|
297
305
|
<br>
|
|
298
306
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
307
|
+
### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
|
|
308
|
+
|
|
309
|
+
<br>
|
|
310
|
+
|
|
311
|
+
### [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)
|
|
312
|
+
|
|
313
|
+
<br>
|
|
314
|
+
|
|
315
|
+
### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
|
|
302
316
|
"""
|
|
303
317
|
self.font_size = font_size
|
|
304
318
|
self.plot_figsize = plot_figsize
|
|
@@ -366,7 +380,7 @@ class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
|
|
|
366
380
|
Configuration for binary classification.
|
|
367
381
|
"""
|
|
368
382
|
def __init__(self,
|
|
369
|
-
cmap: str="
|
|
383
|
+
cmap: str="BuGn",
|
|
370
384
|
ROC_PR_line: str='darkorange',
|
|
371
385
|
calibration_bins: int=15,
|
|
372
386
|
font_size: int=16) -> None:
|
|
@@ -381,7 +395,7 @@ class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
|
|
|
381
395
|
Configuration for multi-class classification.
|
|
382
396
|
"""
|
|
383
397
|
def __init__(self,
|
|
384
|
-
cmap: str="
|
|
398
|
+
cmap: str="BuGn",
|
|
385
399
|
ROC_PR_line: str='darkorange',
|
|
386
400
|
calibration_bins: int=15,
|
|
387
401
|
font_size: int=16) -> None:
|
|
@@ -396,7 +410,7 @@ class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
|
|
|
396
410
|
Configuration for binary image classification.
|
|
397
411
|
"""
|
|
398
412
|
def __init__(self,
|
|
399
|
-
cmap: str="
|
|
413
|
+
cmap: str="BuGn",
|
|
400
414
|
ROC_PR_line: str='darkorange',
|
|
401
415
|
calibration_bins: int=15,
|
|
402
416
|
font_size: int=16) -> None:
|
|
@@ -411,7 +425,7 @@ class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
|
|
|
411
425
|
Configuration for multi-class image classification.
|
|
412
426
|
"""
|
|
413
427
|
def __init__(self,
|
|
414
|
-
cmap: str="
|
|
428
|
+
cmap: str="BuGn",
|
|
415
429
|
ROC_PR_line: str='darkorange',
|
|
416
430
|
calibration_bins: int=15,
|
|
417
431
|
font_size: int=16) -> None:
|
|
@@ -427,11 +441,11 @@ class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
|
|
|
427
441
|
Configuration for multi-label binary classification.
|
|
428
442
|
"""
|
|
429
443
|
def __init__(self,
|
|
444
|
+
cmap: str = "BuGn",
|
|
430
445
|
ROC_PR_line: str='darkorange',
|
|
431
|
-
cmap: str = "Blues",
|
|
432
446
|
font_size: int = 16) -> None:
|
|
433
|
-
super().__init__(
|
|
434
|
-
|
|
447
|
+
super().__init__(cmap=cmap,
|
|
448
|
+
ROC_PR_line=ROC_PR_line,
|
|
435
449
|
font_size=font_size)
|
|
436
450
|
|
|
437
451
|
|
|
@@ -441,8 +455,8 @@ class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
|
|
|
441
455
|
Configuration for binary segmentation.
|
|
442
456
|
"""
|
|
443
457
|
def __init__(self,
|
|
444
|
-
heatmap_cmap: str =
|
|
445
|
-
cm_cmap: str = "
|
|
458
|
+
heatmap_cmap: str = "BuGn",
|
|
459
|
+
cm_cmap: str = "Purples",
|
|
446
460
|
font_size: int = 16) -> None:
|
|
447
461
|
super().__init__(heatmap_cmap=heatmap_cmap,
|
|
448
462
|
cm_cmap=cm_cmap,
|
|
@@ -454,8 +468,8 @@ class MultiClassSegmentationMetricsFormat(_BaseSegmentationFormat):
|
|
|
454
468
|
Configuration for multi-class segmentation.
|
|
455
469
|
"""
|
|
456
470
|
def __init__(self,
|
|
457
|
-
heatmap_cmap: str =
|
|
458
|
-
cm_cmap: str = "
|
|
471
|
+
heatmap_cmap: str = "BuGn",
|
|
472
|
+
cm_cmap: str = "Purples",
|
|
459
473
|
font_size: int = 16) -> None:
|
|
460
474
|
super().__init__(heatmap_cmap=heatmap_cmap,
|
|
461
475
|
cm_cmap=cm_cmap,
|
ml_tools/ML_evaluation.py
CHANGED
|
@@ -169,16 +169,29 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
169
169
|
|
|
170
170
|
# --- Save Classification Report Heatmap ---
|
|
171
171
|
try:
|
|
172
|
-
|
|
172
|
+
# Create DataFrame from report
|
|
173
|
+
report_df = pd.DataFrame(report_dict)
|
|
174
|
+
|
|
175
|
+
# 1. Drop the 'accuracy' column (single float)
|
|
176
|
+
if 'accuracy' in report_df.columns:
|
|
177
|
+
report_df = report_df.drop(columns=['accuracy'])
|
|
178
|
+
|
|
179
|
+
# 2. Select all metric rows *except* the last one ('support')
|
|
180
|
+
# 3. Transpose the DataFrame
|
|
181
|
+
plot_df = report_df.iloc[:-1, :].T
|
|
182
|
+
|
|
183
|
+
fig_height = max(5.0, len(plot_df.index) * 0.5 + 2.0)
|
|
184
|
+
plt.figure(figsize=(7, fig_height), dpi=DPI_value)
|
|
185
|
+
|
|
173
186
|
sns.set_theme(font_scale=1.2) # Scale seaborn font
|
|
174
|
-
sns.heatmap(
|
|
187
|
+
sns.heatmap(plot_df,
|
|
175
188
|
annot=True,
|
|
176
189
|
cmap=format_config.cmap,
|
|
177
190
|
fmt='.2f',
|
|
178
191
|
vmin=0.0,
|
|
179
192
|
vmax=1.0)
|
|
180
193
|
sns.set_theme(font_scale=1.0) # Reset seaborn scale
|
|
181
|
-
plt.title("Classification Report")
|
|
194
|
+
plt.title("Classification Report Heatmap")
|
|
182
195
|
plt.tight_layout()
|
|
183
196
|
heatmap_path = save_dir_path / "classification_report_heatmap.svg"
|
|
184
197
|
plt.savefig(heatmap_path)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|