dragon-ml-toolbox 14.3.1-py3-none-any.whl → 14.8.0-py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/METADATA +2 -1
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/RECORD +17 -16
- ml_tools/ML_configuration.py +116 -0
- ml_tools/ML_datasetmaster.py +42 -0
- ml_tools/ML_evaluation.py +208 -63
- ml_tools/ML_evaluation_multi.py +40 -10
- ml_tools/ML_trainer.py +38 -12
- ml_tools/ML_utilities.py +50 -1
- ml_tools/ML_vision_datasetmaster.py +198 -60
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +151 -6
- ml_tools/ensemble_evaluation.py +53 -10
- ml_tools/keys.py +2 -1
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation_multi.py
CHANGED
@@ -34,6 +34,8 @@ __all__ = [
     "multi_target_shap_summary_plot",
 ]
 
+DPI_value = 250
+
 
 def multi_target_regression_metrics(
     y_true: np.ndarray,
@@ -90,7 +92,7 @@ def multi_target_regression_metrics(
 
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
-        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
         ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
         ax_res.axhline(0, color='red', linestyle='--')
         ax_res.set_xlabel("Predicted Values")
@@ -103,7 +105,7 @@ def multi_target_regression_metrics(
         plt.close(fig_res)
 
         # --- Save True vs. Predicted Plot ---
-        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
         ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
         ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
         ax_tvp.set_xlabel('True Values')
@@ -127,7 +129,10 @@ def multi_label_classification_metrics(
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-    threshold: float = 0.5
+    threshold: float = 0.5,
+    ROC_PR_line: str='darkorange',
+    cmap: str = "Blues",
+    font_size: int = 16
 ):
     """
     Calculates and saves classification metrics for each label individually.
@@ -158,6 +163,10 @@ def multi_label_classification_metrics(
 
     # Generate binary predictions from probabilities
     y_pred = (y_prob >= threshold).astype(int)
+
+    # --- Save current RC params and update font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': font_size})
 
     _LOGGER.info("--- Multi-Label Classification Evaluation ---")
 
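Editor's note: the added lines snapshot `plt.rcParams`, apply the requested `font_size`, and restore the snapshot at the end of the function (see the final hunk of this file). A minimal sketch of the same idea using matplotlib's built-in context manager, which reverts the settings automatically even if plotting raises; this is an illustrative alternative, not what the package does:

```python
import matplotlib.pyplot as plt

# Scoped rcParams override: settings revert when the block exits,
# even if an exception is raised while plotting.
with plt.rc_context({'font.size': 16}):
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_title("Rendered at font size 16")
plt.close(fig)
```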
@@ -174,7 +183,7 @@ def multi_label_classification_metrics(
         f"Jaccard Score (macro): {j_score_macro:.4f}\n"
         f"--------------------------------------------------\n"
     )
-    print(overall_report)
+    # print(overall_report)
     overall_report_path = save_dir_path / "classification_report_overall.txt"
     overall_report_path.write_text(overall_report)
 
@@ -192,8 +201,26 @@ def multi_label_classification_metrics(
         report_path.write_text(report_text) # type: ignore
 
         # --- Save Confusion Matrix ---
-        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=
-        ConfusionMatrixDisplay.from_predictions(true_i,
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                        pred_i,
+                                                        cmap=cmap,
+                                                        ax=ax_cm,
+                                                        normalize='true',
+                                                        labels=[0, 1],
+                                                        display_labels=["Negative", "Positive"])
+
+        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+        # Turn off gridlines
+        ax_cm.grid(False)
+
+        # Manually update font size of cell texts
+        for text in ax_cm.texts:
+            text.set_fontsize(font_size)
+
+        fig_cm.tight_layout()
+
         ax_cm.set_title(f"Confusion Matrix for '{name}'")
         cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
         plt.savefig(cm_path)
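Editor's note: the new confusion-matrix block normalizes each row to proportions (`normalize='true'`) and pins the colormap to [0, 1] with `set_clim`, so the matrices produced for different labels share one color scale and stay visually comparable. A self-contained sketch of that pattern with scikit-learn, on illustrative data:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

y_true = [0, 1, 1, 0, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 0, 1, 1]

fig, ax = plt.subplots(figsize=(6, 6))
disp = ConfusionMatrixDisplay.from_predictions(
    y_true, y_pred,
    normalize='true',                        # each row sums to 1
    display_labels=["Negative", "Positive"],
    cmap="Blues",
    ax=ax,
)
disp.im_.set_clim(vmin=0.0, vmax=1.0)        # one color scale for every label's plot
ax.grid(False)
fig.tight_layout()
plt.close(fig)
```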
@@ -202,8 +229,8 @@ def multi_label_classification_metrics(
         # --- Save ROC Curve ---
         fpr, tpr, _ = roc_curve(true_i, prob_i)
         auc = roc_auc_score(true_i, prob_i)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -215,14 +242,17 @@ def multi_label_classification_metrics(
         # --- Save Precision-Recall Curve ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
         pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
         plt.savefig(pr_path)
         plt.close(fig_pr)
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
     _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
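Editor's note: net effect for this module is that a shared `DPI_value` constant replaces per-call DPI literals, and `multi_label_classification_metrics` exposes curve color, colormap, and font size as parameters. A hedged usage sketch with synthetic data; the positional order of `y_true`/`y_prob` is assumed from the signature shown above:

```python
import numpy as np
from ml_tools.ML_evaluation_multi import multi_label_classification_metrics

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(100, 3))   # 100 samples, 3 binary labels
y_prob = rng.random(size=(100, 3))           # per-label probabilities

multi_label_classification_metrics(
    y_true, y_prob,
    target_names=["label_a", "label_b", "label_c"],
    save_dir="eval_multi_label",
    threshold=0.5,               # binarization cutoff (pre-existing parameter)
    ROC_PR_line="darkorange",    # new: line color for ROC and PR curves
    cmap="Blues",                # new: confusion-matrix colormap
    font_size=16,                # new: applied via plt.rcParams while plotting
)
```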
ml_tools/ML_trainer.py
CHANGED
@@ -13,11 +13,12 @@ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
+from .ML_configuration import ClassificationMetricsFormat, MultiClassificationMetricsFormat
 
 
 __all__ = [
     "MLTrainer",
-    "ObjectDetectionTrainer"
+    "ObjectDetectionTrainer",
 ]
 
 
@@ -334,14 +335,16 @@ class MLTrainer:
 
             yield y_pred_batch, y_prob_batch, y_true_batch
 
-    def evaluate(self,
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 data: Optional[Union[DataLoader, Dataset]] = None,
+                 format_configuration: Optional[Union[ClassificationMetricsFormat, MultiClassificationMetricsFormat]]=None):
         """
         Evaluates the model, routing to the correct evaluation function based on task `kind`.
 
         Args:
             save_dir (str | Path): Directory to save all reports and plots.
             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
-            classification_threshold (float): Probability threshold for multi-label tasks.
         """
         dataset_for_names = None
         eval_loader = None
@@ -376,10 +379,10 @@ class MLTrainer:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
             raise ValueError()
 
-        print("\n--- Model Evaluation ---")
+        # print("\n--- Model Evaluation ---")
 
         all_preds, all_probs, all_true = [], [], []
-        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
             if y_pred_b is not None: all_preds.append(y_pred_b)
             if y_prob_b is not None: all_probs.append(y_prob_b)
             if y_true_b is not None: all_true.append(y_true_b)
@@ -397,7 +400,19 @@ class MLTrainer:
             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
 
         elif self.kind == "classification":
-
+            # Parse configuration
+            if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
+                classification_metrics(save_dir=save_dir,
+                                       y_true=y_true,
+                                       y_pred=y_pred,
+                                       y_prob=y_prob,
+                                       cmap=format_configuration.cmap,
+                                       class_map=format_configuration.class_map,
+                                       ROC_PR_line=format_configuration.ROC_PR_line,
+                                       calibration_bins=format_configuration.calibration_bins,
+                                       font_size=format_configuration.font_size)
+            else:
+                classification_metrics(save_dir, y_true, y_pred, y_prob)
 
         elif self.kind == "multi_target_regression":
             try:
@@ -419,7 +434,18 @@ class MLTrainer:
             if y_prob is None:
                 _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
                 return
-
+
+            if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
+                multi_label_classification_metrics(y_true=y_true,
+                                                   y_prob=y_prob,
+                                                   target_names=target_names,
+                                                   save_dir=save_dir,
+                                                   threshold=format_configuration.threshold,
+                                                   ROC_PR_line=format_configuration.ROC_PR_line,
+                                                   cmap=format_configuration.cmap,
+                                                   font_size=format_configuration.font_size)
+            else:
+                multi_label_classification_metrics(y_true, y_prob, target_names, save_dir)
 
         elif self.kind == "segmentation":
             class_names = None
@@ -445,7 +471,7 @@ class MLTrainer:
 
             segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
 
-        print("\n--- Training History ---")
+        # print("\n--- Training History ---")
         plot_losses(self.history, save_dir=save_dir)
 
     def explain(self,
@@ -502,7 +528,7 @@ class MLTrainer:
             rand_indices = torch.randperm(full_data.size(0))[:num_samples]
             return full_data[rand_indices]
 
-        print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
+        # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
 
         # 1. Get background data from the trainer's train_dataset
         background_data = _get_random_sample(self.train_dataset, n_samples)
@@ -610,7 +636,7 @@ class MLTrainer:
             plot_n_features (int): Number of top features to plot.
         """
 
-        print("\n--- Attention Analysis ---")
+        # print("\n--- Attention Analysis ---")
 
         # --- Step 1: Check if the model supports this explanation ---
         if not getattr(self.model, 'has_interpretable_attention', False):
@@ -994,7 +1020,7 @@ class ObjectDetectionTrainer:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
             raise ValueError()
 
-        print("\n--- Model Evaluation ---")
+        # print("\n--- Model Evaluation ---")
 
         all_predictions = []
         all_targets = []
@@ -1043,7 +1069,7 @@ class ObjectDetectionTrainer:
             print_output=False
         )
 
-        print("\n--- Training History ---")
+        # print("\n--- Training History ---")
         plot_losses(self.history, save_dir=save_dir)
 
     def _callbacks_hook(self, method_name: str, *args, **kwargs):
ml_tools/ML_utilities.py
CHANGED
@@ -10,6 +10,7 @@ from ._logger import _LOGGER
 from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys, UtilityKeys, PyTorchCheckpointKeys
 from .utilities import load_dataframe
 from .custom_logger import save_list_strings, custom_logger
+from .serde import serialize_object_filename
 
 
 __all__ = [
@@ -18,7 +19,8 @@ __all__ = [
     "get_model_parameters",
     "inspect_model_architecture",
    "inspect_pth_file",
-    "set_parameter_requires_grad"
+    "set_parameter_requires_grad",
+    "save_pretrained_transforms"
 ]
 
 
@@ -524,5 +526,52 @@ def _set_params_grad(
     return params_changed
 
 
+def save_pretrained_transforms(model: nn.Module, output_dir: Union[str, Path]):
+    """
+    Checks a model for the 'self._pretrained_default_transforms' attribute, if found,
+    serializes the returned transform object as a .joblib file.
+
+    This saves the callable transform object itself for
+    later use, such as passing it directly to the 'transform_source'
+    argument of the PyTorchVisionInferenceHandler.
+
+    Args:
+        model (nn.Module): The model instance to check.
+        output_dir (str | Path): The directory where the transform file will be saved.
+    """
+    output_filename = "pretrained_model_transformations"
+
+    # 1. Check for the "secret attribute"
+    if not hasattr(model, '_pretrained_default_transforms'):
+        _LOGGER.warning(f"Model of type {type(model).__name__} does not have the required attribute. No transformations saved.")
+        return
+
+    # 2. Get the transform object
+    try:
+        transform_obj = model._pretrained_default_transforms
+    except Exception as e:
+        _LOGGER.error(f"Error calling the required attribute on model: {e}")
+        return
+
+    # 3. Check if the object is actually there
+    if transform_obj is None:
+        _LOGGER.warning(f"Model {type(model).__name__} has the required attribute but returned None. No transforms saved.")
+        return
+
+    # 4. Serialize and save using serde
+    try:
+        serialize_object_filename(
+            obj=transform_obj,
+            save_dir=output_dir,
+            filename=output_filename,
+            verbose=True,
+            raise_on_error=True
+        )
+        # _LOGGER.info(f"Successfully saved pretrained transforms to '{output_dir}'.")
+    except Exception as e:
+        _LOGGER.error(f"Failed to serialize transformations: {e}")
+        raise
+
+
 def info():
     _script_info(__all__)
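Editor's note: a hedged usage sketch for the new `save_pretrained_transforms` helper. Per its docstring, the transform object is written as a .joblib file named `pretrained_model_transformations`; the model class below is hypothetical and only stands in for any ml_tools vision model that sets `_pretrained_default_transforms`:

```python
from pathlib import Path
import torch.nn as nn
from ml_tools.ML_utilities import save_pretrained_transforms

class DummyVisionModel(nn.Module):
    """Hypothetical stand-in for an ml_tools vision model that records the
    preprocessing pipeline of its pretrained weights."""
    def __init__(self):
        super().__init__()
        # Real models would store the weights' default transform callable here;
        # nn.Identity is just a picklable placeholder for this sketch.
        self._pretrained_default_transforms = nn.Identity()

save_pretrained_transforms(DummyVisionModel(), Path("artifacts"))
# Expected artifact (per the docstring): artifacts/pretrained_model_transformations.joblib,
# later reusable as the `transform_source` of PyTorchVisionInferenceHandler.
```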