dragon-ml-toolbox 14.3.1__py3-none-any.whl → 14.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dragon-ml-toolbox might be problematic.

@@ -34,6 +34,8 @@ __all__ = [
     "multi_target_shap_summary_plot",
 ]
 
+DPI_value = 250
+
 
 def multi_target_regression_metrics(
     y_true: np.ndarray,
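The new module-level DPI_value constant replaces the hard-coded dpi=100 in every figure below, raising saved plots from 100 to 250 dpi. Because it is a constant rather than a function parameter, a caller who wants a different resolution would override it on the module before calling the metric functions; a minimal sketch, with the module name assumed for illustration (this diff does not show which ml_tools module the hunk belongs to):

    # Hypothetical import: the module name is an assumption, not shown in this diff.
    from ml_tools import ML_evaluation_multi

    ML_evaluation_multi.DPI_value = 300  # e.g. print-quality figures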
@@ -90,7 +92,7 @@ def multi_target_regression_metrics(
 
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
-        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
         ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
         ax_res.axhline(0, color='red', linestyle='--')
         ax_res.set_xlabel("Predicted Values")
@@ -103,7 +105,7 @@ def multi_target_regression_metrics(
         plt.close(fig_res)
 
         # --- Save True vs. Predicted Plot ---
-        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
         ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
         ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
         ax_tvp.set_xlabel('True Values')
@@ -127,7 +129,10 @@ def multi_label_classification_metrics(
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-    threshold: float = 0.5
+    threshold: float = 0.5,
+    ROC_PR_line: str='darkorange',
+    cmap: str = "Blues",
+    font_size: int = 16
     ):
     """
     Calculates and saves classification metrics for each label individually.
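With the widened signature, per-label plot styling is now set at the call site instead of being hard-coded. A minimal usage sketch with illustrative values (data shapes follow the existing behavior: one column per label; the import of multi_label_classification_metrics is omitted because its module path is not shown in this diff):

    import numpy as np

    # Illustrative data only: 3 labels, random probabilities.
    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 2, size=(200, 3))
    y_prob = rng.random(size=(200, 3))

    multi_label_classification_metrics(
        y_true=y_true,
        y_prob=y_prob,
        target_names=["label_a", "label_b", "label_c"],
        save_dir="reports/multilabel",
        threshold=0.5,
        ROC_PR_line="crimson",  # color of the ROC and PR curve lines
        cmap="Greens",          # confusion-matrix colormap
        font_size=14,           # applied via plt.rcParams while plotting
    )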
@@ -158,6 +163,10 @@ def multi_label_classification_metrics(
 
     # Generate binary predictions from probabilities
     y_pred = (y_prob >= threshold).astype(int)
+
+    # --- Save current RC params and update font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': font_size})
 
     _LOGGER.info("--- Multi-Label Classification Evaluation ---")
 
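The font size is applied by mutating the global plt.rcParams and restoring the saved copy at the end of the function (last hunk of this file). An equivalent, exception-safe alternative is matplotlib's rc_context manager, sketched here for comparison rather than as the library's implementation:

    import matplotlib.pyplot as plt

    # plt.rc_context restores the previous rcParams automatically,
    # even if an exception is raised while plotting.
    with plt.rc_context({'font.size': font_size}):
        fig, ax = plt.subplots(figsize=(6, 6), dpi=DPI_value)
        ...  # plotting code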
@@ -174,7 +183,7 @@ def multi_label_classification_metrics(
         f"Jaccard Score (macro): {j_score_macro:.4f}\n"
         f"--------------------------------------------------\n"
     )
-    print(overall_report)
+    # print(overall_report)
     overall_report_path = save_dir_path / "classification_report_overall.txt"
     overall_report_path.write_text(overall_report)
 
@@ -192,8 +201,26 @@ def multi_label_classification_metrics(
         report_path.write_text(report_text) # type: ignore
 
         # --- Save Confusion Matrix ---
-        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-        ConfusionMatrixDisplay.from_predictions(true_i, pred_i, cmap="Blues", ax=ax_cm)
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                        pred_i,
+                                                        cmap=cmap,
+                                                        ax=ax_cm,
+                                                        normalize='true',
+                                                        labels=[0, 1],
+                                                        display_labels=["Negative", "Positive"])
+
+        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+        # Turn off gridlines
+        ax_cm.grid(False)
+
+        # Manually update font size of cell texts
+        for text in ax_cm.texts:
+            text.set_fontsize(font_size)
+
+        fig_cm.tight_layout()
+
         ax_cm.set_title(f"Confusion Matrix for '{name}'")
         cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
         plt.savefig(cm_path)
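The rewritten confusion-matrix block changes more than styling: normalize='true' makes each row sum to 1 (cells show per-class recall instead of raw counts), disp_.im_.set_clim(0.0, 1.0) pins the color scale so matrices are visually comparable across labels, and labels/display_labels fix the binary Negative/Positive ordering even when one class is absent from a label's predictions.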
@@ -202,8 +229,8 @@ def multi_label_classification_metrics(
         # --- Save ROC Curve ---
         fpr, tpr, _ = roc_curve(true_i, prob_i)
         auc = roc_auc_score(true_i, prob_i)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -215,14 +242,17 @@ def multi_label_classification_metrics(
         # --- Save Precision-Recall Curve ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
         pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
         plt.savefig(pr_path)
         plt.close(fig_pr)
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
     _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
 
ml_tools/ML_trainer.py CHANGED
@@ -13,11 +13,12 @@ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from .ML_vision_evaluation import segmentation_metrics, object_detection_metrics
+from .ML_configuration import ClassificationMetricsFormat, MultiClassificationMetricsFormat
 
 
 __all__ = [
     "MLTrainer",
-    "ObjectDetectionTrainer"
+    "ObjectDetectionTrainer",
 ]
 
 
@@ -334,14 +335,16 @@ class MLTrainer:
 
             yield y_pred_batch, y_prob_batch, y_true_batch
 
-    def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None, classification_threshold: float = 0.5):
+    def evaluate(self,
+                 save_dir: Union[str, Path],
+                 data: Optional[Union[DataLoader, Dataset]] = None,
+                 format_configuration: Optional[Union[ClassificationMetricsFormat, MultiClassificationMetricsFormat]]=None):
         """
         Evaluates the model, routing to the correct evaluation function based on task `kind`.
 
         Args:
             save_dir (str | Path): Directory to save all reports and plots.
             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
-            classification_threshold (float): Probability threshold for multi-label tasks.
         """
         dataset_for_names = None
         eval_loader = None
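Note that the classification_threshold argument is gone from evaluate(): the prediction threshold for multi-label tasks now travels inside the MultiClassificationMetricsFormat configuration object (see the dispatch hunks below), and _predict_for_eval is now called without it.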
@@ -376,10 +379,10 @@ class MLTrainer:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
             raise ValueError()
 
-        print("\n--- Model Evaluation ---")
+        # print("\n--- Model Evaluation ---")
 
         all_preds, all_probs, all_true = [], [], []
-        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader, classification_threshold):
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
             if y_pred_b is not None: all_preds.append(y_pred_b)
             if y_prob_b is not None: all_probs.append(y_prob_b)
             if y_true_b is not None: all_true.append(y_true_b)
@@ -397,7 +400,19 @@ class MLTrainer:
             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
 
         elif self.kind == "classification":
-            classification_metrics(save_dir, y_true, y_pred, y_prob)
+            # Parse configuration
+            if format_configuration and isinstance(format_configuration, ClassificationMetricsFormat):
+                classification_metrics(save_dir=save_dir,
+                                       y_true=y_true,
+                                       y_pred=y_pred,
+                                       y_prob=y_prob,
+                                       cmap=format_configuration.cmap,
+                                       class_map=format_configuration.class_map,
+                                       ROC_PR_line=format_configuration.ROC_PR_line,
+                                       calibration_bins=format_configuration.calibration_bins,
+                                       font_size=format_configuration.font_size)
+            else:
+                classification_metrics(save_dir, y_true, y_pred, y_prob)
 
         elif self.kind == "multi_target_regression":
             try:
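A usage sketch of the new configuration-driven dispatch, assuming ClassificationMetricsFormat accepts these fields as keyword arguments (its constructor lives in ml_tools/ML_configuration.py and is not part of this diff; values are illustrative):

    from ml_tools.ML_configuration import ClassificationMetricsFormat

    # Only the attributes read above (cmap, class_map, ROC_PR_line,
    # calibration_bins, font_size) are known from this diff; the constructor
    # call below is an assumption.
    fmt = ClassificationMetricsFormat(
        cmap="Blues",
        class_map={0: "negative", 1: "positive"},  # illustrative mapping
        ROC_PR_line="darkorange",
        calibration_bins=15,
        font_size=16,
    )

    trainer.evaluate(save_dir="reports/classification", format_configuration=fmt)

The multi-label branch in the next hunk follows the same pattern with MultiClassificationMetricsFormat, which additionally carries the probability threshold.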
@@ -419,7 +434,18 @@ class MLTrainer:
             if y_prob is None:
                 _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
                 return
-            multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
+
+            if format_configuration and isinstance(format_configuration, MultiClassificationMetricsFormat):
+                multi_label_classification_metrics(y_true=y_true,
+                                                   y_prob=y_prob,
+                                                   target_names=target_names,
+                                                   save_dir=save_dir,
+                                                   threshold=format_configuration.threshold,
+                                                   ROC_PR_line=format_configuration.ROC_PR_line,
+                                                   cmap=format_configuration.cmap,
+                                                   font_size=format_configuration.font_size)
+            else:
+                multi_label_classification_metrics(y_true, y_prob, target_names, save_dir)
 
         elif self.kind == "segmentation":
             class_names = None
@@ -445,7 +471,7 @@ class MLTrainer:
 
             segmentation_metrics(y_true, y_pred, save_dir, class_names=class_names)
 
-        print("\n--- Training History ---")
+        # print("\n--- Training History ---")
         plot_losses(self.history, save_dir=save_dir)
 
     def explain(self,
@@ -502,7 +528,7 @@ class MLTrainer:
             rand_indices = torch.randperm(full_data.size(0))[:num_samples]
             return full_data[rand_indices]
 
-        print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
+        # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
 
         # 1. Get background data from the trainer's train_dataset
         background_data = _get_random_sample(self.train_dataset, n_samples)
@@ -610,7 +636,7 @@ class MLTrainer:
             plot_n_features (int): Number of top features to plot.
         """
 
-        print("\n--- Attention Analysis ---")
+        # print("\n--- Attention Analysis ---")
 
         # --- Step 1: Check if the model supports this explanation ---
         if not getattr(self.model, 'has_interpretable_attention', False):
@@ -994,7 +1020,7 @@ class ObjectDetectionTrainer:
             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
             raise ValueError()
 
-        print("\n--- Model Evaluation ---")
+        # print("\n--- Model Evaluation ---")
 
         all_predictions = []
         all_targets = []
@@ -1043,7 +1069,7 @@ class ObjectDetectionTrainer:
             print_output=False
         )
 
-        print("\n--- Training History ---")
+        # print("\n--- Training History ---")
         plot_losses(self.history, save_dir=save_dir)
 
     def _callbacks_hook(self, method_name: str, *args, **kwargs):
ml_tools/ML_utilities.py CHANGED
@@ -10,6 +10,7 @@ from ._logger import _LOGGER
 from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys, UtilityKeys, PyTorchCheckpointKeys
 from .utilities import load_dataframe
 from .custom_logger import save_list_strings, custom_logger
+from .serde import serialize_object_filename
 
 
 __all__ = [
@@ -18,7 +19,8 @@ __all__ = [
     "get_model_parameters",
     "inspect_model_architecture",
     "inspect_pth_file",
-    "set_parameter_requires_grad"
+    "set_parameter_requires_grad",
+    "save_pretrained_transforms"
 ]
 
 
@@ -524,5 +526,52 @@ def _set_params_grad(
     return params_changed
 
 
+def save_pretrained_transforms(model: nn.Module, output_dir: Union[str, Path]):
+    """
+    Checks a model for the 'self._pretrained_default_transforms' attribute, if found,
+    serializes the returned transform object as a .joblib file.
+
+    This saves the callable transform object itself for
+    later use, such as passing it directly to the 'transform_source'
+    argument of the PyTorchVisionInferenceHandler.
+
+    Args:
+        model (nn.Module): The model instance to check.
+        output_dir (str | Path): The directory where the transform file will be saved.
+    """
+    output_filename = "pretrained_model_transformations"
+
+    # 1. Check for the "secret attribute"
+    if not hasattr(model, '_pretrained_default_transforms'):
+        _LOGGER.warning(f"Model of type {type(model).__name__} does not have the required attribute. No transformations saved.")
+        return
+
+    # 2. Get the transform object
+    try:
+        transform_obj = model._pretrained_default_transforms
+    except Exception as e:
+        _LOGGER.error(f"Error calling the required attribute on model: {e}")
+        return
+
+    # 3. Check if the object is actually there
+    if transform_obj is None:
+        _LOGGER.warning(f"Model {type(model).__name__} has the required attribute but returned None. No transforms saved.")
+        return
+
+    # 4. Serialize and save using serde
+    try:
+        serialize_object_filename(
+            obj=transform_obj,
+            save_dir=output_dir,
+            filename=output_filename,
+            verbose=True,
+            raise_on_error=True
+        )
+        # _LOGGER.info(f"Successfully saved pretrained transforms to '{output_dir}'.")
+    except Exception as e:
+        _LOGGER.error(f"Failed to serialize transformations: {e}")
+        raise
+
+
 def info():
     _script_info(__all__)
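The new save_pretrained_transforms helper rounds out the inference workflow for vision models whose wrapper stores the pretrained preprocessing pipeline on _pretrained_default_transforms. A minimal usage sketch (the model construction is illustrative; any nn.Module exposing that attribute works):

    from ml_tools.ML_utilities import save_pretrained_transforms

    model = ...  # a toolbox vision wrapper that sets `_pretrained_default_transforms`
    save_pretrained_transforms(model, output_dir="artifacts")
    # Per the docstring, this serializes the transform object (joblib) under 'artifacts',
    # ready to be passed as the 'transform_source' of PyTorchVisionInferenceHandler.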