dragon-ml-toolbox 20.2.0__py3-none-any.whl → 20.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/METADATA +1 -1
  2. dragon_ml_toolbox-20.4.0.dist-info/RECORD +143 -0
  3. ml_tools/ETL_cleaning/__init__.py +5 -1
  4. ml_tools/ETL_cleaning/_basic_clean.py +1 -1
  5. ml_tools/ETL_engineering/__init__.py +5 -1
  6. ml_tools/GUI_tools/__init__.py +5 -1
  7. ml_tools/IO_tools/_IO_loggers.py +33 -21
  8. ml_tools/IO_tools/__init__.py +5 -1
  9. ml_tools/MICE/__init__.py +8 -2
  10. ml_tools/MICE/_dragon_mice.py +1 -1
  11. ml_tools/ML_callbacks/__init__.py +5 -1
  12. ml_tools/ML_chain/__init__.py +5 -1
  13. ml_tools/ML_configuration/__init__.py +7 -1
  14. ml_tools/ML_configuration/_training.py +65 -1
  15. ml_tools/ML_datasetmaster/__init__.py +5 -1
  16. ml_tools/ML_datasetmaster/_base_datasetmaster.py +31 -20
  17. ml_tools/ML_datasetmaster/_datasetmaster.py +26 -9
  18. ml_tools/ML_datasetmaster/_sequence_datasetmaster.py +38 -23
  19. ml_tools/ML_evaluation/__init__.py +5 -1
  20. ml_tools/ML_evaluation/_classification.py +10 -2
  21. ml_tools/ML_evaluation_captum/__init__.py +5 -1
  22. ml_tools/ML_finalize_handler/__init__.py +5 -1
  23. ml_tools/ML_inference/__init__.py +5 -1
  24. ml_tools/ML_inference_sequence/__init__.py +5 -1
  25. ml_tools/ML_inference_vision/__init__.py +5 -1
  26. ml_tools/ML_models/__init__.py +21 -6
  27. ml_tools/ML_models/_dragon_autoint.py +302 -0
  28. ml_tools/ML_models/_dragon_gate.py +358 -0
  29. ml_tools/ML_models/_dragon_node.py +268 -0
  30. ml_tools/ML_models/_dragon_tabnet.py +255 -0
  31. ml_tools/ML_models_sequence/__init__.py +5 -1
  32. ml_tools/ML_models_vision/__init__.py +5 -1
  33. ml_tools/ML_optimization/__init__.py +11 -3
  34. ml_tools/ML_optimization/_multi_dragon.py +24 -8
  35. ml_tools/ML_optimization/_single_dragon.py +47 -67
  36. ml_tools/ML_optimization/_single_manual.py +1 -1
  37. ml_tools/ML_scaler/_ML_scaler.py +12 -7
  38. ml_tools/ML_scaler/__init__.py +5 -1
  39. ml_tools/ML_trainer/__init__.py +5 -1
  40. ml_tools/ML_trainer/_base_trainer.py +136 -13
  41. ml_tools/ML_trainer/_dragon_detection_trainer.py +31 -91
  42. ml_tools/ML_trainer/_dragon_sequence_trainer.py +24 -74
  43. ml_tools/ML_trainer/_dragon_trainer.py +24 -85
  44. ml_tools/ML_utilities/__init__.py +5 -1
  45. ml_tools/ML_utilities/_inspection.py +44 -30
  46. ml_tools/ML_vision_transformers/__init__.py +8 -2
  47. ml_tools/PSO_optimization/__init__.py +5 -1
  48. ml_tools/SQL/__init__.py +8 -2
  49. ml_tools/VIF/__init__.py +5 -1
  50. ml_tools/data_exploration/__init__.py +4 -1
  51. ml_tools/data_exploration/_cleaning.py +4 -2
  52. ml_tools/ensemble_evaluation/__init__.py +5 -1
  53. ml_tools/ensemble_inference/__init__.py +5 -1
  54. ml_tools/ensemble_learning/__init__.py +5 -1
  55. ml_tools/excel_handler/__init__.py +5 -1
  56. ml_tools/keys/__init__.py +5 -1
  57. ml_tools/keys/_keys.py +1 -1
  58. ml_tools/math_utilities/__init__.py +5 -1
  59. ml_tools/optimization_tools/__init__.py +5 -1
  60. ml_tools/path_manager/__init__.py +8 -2
  61. ml_tools/plot_fonts/__init__.py +8 -2
  62. ml_tools/schema/__init__.py +8 -2
  63. ml_tools/schema/_feature_schema.py +3 -3
  64. ml_tools/serde/__init__.py +5 -1
  65. ml_tools/utilities/__init__.py +5 -1
  66. ml_tools/utilities/_utility_save_load.py +38 -20
  67. dragon_ml_toolbox-20.2.0.dist-info/RECORD +0 -179
  68. ml_tools/ETL_cleaning/_imprimir.py +0 -13
  69. ml_tools/ETL_engineering/_imprimir.py +0 -24
  70. ml_tools/GUI_tools/_imprimir.py +0 -12
  71. ml_tools/IO_tools/_imprimir.py +0 -14
  72. ml_tools/MICE/_imprimir.py +0 -11
  73. ml_tools/ML_callbacks/_imprimir.py +0 -12
  74. ml_tools/ML_chain/_imprimir.py +0 -12
  75. ml_tools/ML_configuration/_imprimir.py +0 -47
  76. ml_tools/ML_datasetmaster/_imprimir.py +0 -15
  77. ml_tools/ML_evaluation/_imprimir.py +0 -25
  78. ml_tools/ML_evaluation_captum/_imprimir.py +0 -10
  79. ml_tools/ML_finalize_handler/_imprimir.py +0 -8
  80. ml_tools/ML_inference/_imprimir.py +0 -11
  81. ml_tools/ML_inference_sequence/_imprimir.py +0 -8
  82. ml_tools/ML_inference_vision/_imprimir.py +0 -8
  83. ml_tools/ML_models/_advanced_models.py +0 -1086
  84. ml_tools/ML_models/_imprimir.py +0 -18
  85. ml_tools/ML_models_sequence/_imprimir.py +0 -8
  86. ml_tools/ML_models_vision/_imprimir.py +0 -16
  87. ml_tools/ML_optimization/_imprimir.py +0 -13
  88. ml_tools/ML_scaler/_imprimir.py +0 -8
  89. ml_tools/ML_trainer/_imprimir.py +0 -10
  90. ml_tools/ML_utilities/_imprimir.py +0 -16
  91. ml_tools/ML_vision_transformers/_imprimir.py +0 -14
  92. ml_tools/PSO_optimization/_imprimir.py +0 -10
  93. ml_tools/SQL/_imprimir.py +0 -8
  94. ml_tools/VIF/_imprimir.py +0 -10
  95. ml_tools/data_exploration/_imprimir.py +0 -32
  96. ml_tools/ensemble_evaluation/_imprimir.py +0 -14
  97. ml_tools/ensemble_inference/_imprimir.py +0 -9
  98. ml_tools/ensemble_learning/_imprimir.py +0 -10
  99. ml_tools/excel_handler/_imprimir.py +0 -13
  100. ml_tools/keys/_imprimir.py +0 -11
  101. ml_tools/math_utilities/_imprimir.py +0 -11
  102. ml_tools/optimization_tools/_imprimir.py +0 -13
  103. ml_tools/path_manager/_imprimir.py +0 -15
  104. ml_tools/plot_fonts/_imprimir.py +0 -8
  105. ml_tools/schema/_imprimir.py +0 -10
  106. ml_tools/serde/_imprimir.py +0 -10
  107. ml_tools/utilities/_imprimir.py +0 -18
  108. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/WHEEL +0 -0
  109. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/licenses/LICENSE +0 -0
  110. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  111. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/top_level.txt +0 -0
@@ -80,26 +80,12 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
80
80
 
81
81
  def _create_dataloaders(self, batch_size: int, shuffle: bool):
82
82
  """Initializes the DataLoaders with the object detection collate_fn."""
83
- # Ensure stability on MPS devices by setting num_workers to 0
84
- loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
85
-
86
- self.train_loader = DataLoader(
87
- dataset=self.train_dataset,
88
- batch_size=batch_size,
89
- shuffle=shuffle,
90
- num_workers=loader_workers,
91
- pin_memory=("cuda" in self.device.type),
92
- collate_fn=self.collate_fn, # Use the provided collate function
93
- drop_last=True
94
- )
95
-
96
- self.validation_loader = DataLoader(
97
- dataset=self.validation_dataset,
98
- batch_size=batch_size,
99
- shuffle=False,
100
- num_workers=loader_workers,
101
- pin_memory=("cuda" in self.device.type),
102
- collate_fn=self.collate_fn # Use the provided collate function
83
+ self._make_dataloaders(
84
+ train_dataset=self.train_dataset,
85
+ validation_dataset=self.validation_dataset,
86
+ batch_size=batch_size,
87
+ shuffle=shuffle,
88
+ collate_fn=self.collate_fn
103
89
  )
104
90
 
105
91
  def _train_step(self):
@@ -207,17 +193,9 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
207
193
  - If 'current', use the current state of the trained model up the latest trained epoch.
208
194
  test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
209
195
  """
210
- # Validate model checkpoint
211
- if isinstance(model_checkpoint, Path):
212
- checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
213
- elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
214
- checkpoint_validated = model_checkpoint
215
- else:
216
- _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
217
- raise ValueError()
218
-
219
- # Validate directory
220
- save_path = make_fullpath(save_dir, make=True, enforce="directory")
196
+ # Validate inputs using base helpers
197
+ checkpoint_validated = self._validate_checkpoint_arg(model_checkpoint)
198
+ save_path = self._validate_save_dir(save_dir)
221
199
 
222
200
  # Validate test data and dispatch
223
201
  if test_data is not None:
@@ -230,21 +208,21 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
230
208
  test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
231
209
 
232
210
  # Dispatch validation set
233
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
211
+ _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
234
212
  self._evaluate(save_dir=validation_metrics_path,
235
- model_checkpoint=checkpoint_validated,
213
+ model_checkpoint=checkpoint_validated, # type: ignore
236
214
  data=None) # 'None' triggers use of self.test_dataset
237
215
 
238
216
  # Dispatch test set
239
- _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
217
+ _LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
240
218
  self._evaluate(save_dir=test_metrics_path,
241
219
  model_checkpoint="current", # Use 'current' state after loading checkpoint once
242
220
  data=test_data_validated)
243
221
  else:
244
222
  # Dispatch validation set
245
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
223
+ _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
246
224
  self._evaluate(save_dir=save_path,
247
- model_checkpoint=checkpoint_validated,
225
+ model_checkpoint=checkpoint_validated, # type: ignore
248
226
  data=None) # 'None' triggers use of self.test_dataset
249
227
 
250
228
  def _evaluate(self,
@@ -263,54 +241,17 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
263
241
  - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
264
242
  - If 'current', use the current state of the trained model up the latest trained epoch.
265
243
  """
266
- dataset_for_artifacts = None
267
- eval_loader = None
268
-
269
244
  # load model checkpoint
270
- if isinstance(model_checkpoint, Path):
271
- self._load_checkpoint(path=model_checkpoint)
272
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
273
- path_to_latest = self._checkpoint_callback.best_checkpoint_path
274
- self._load_checkpoint(path_to_latest)
275
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
276
- _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
277
- raise ValueError()
278
-
279
- # Dataloader
280
- if isinstance(data, DataLoader):
281
- eval_loader = data
282
- if hasattr(data, 'dataset'):
283
- dataset_for_artifacts = data.dataset # type: ignore
284
- elif isinstance(data, Dataset):
285
- # Create a new loader from the provided dataset
286
- eval_loader = DataLoader(data,
287
- batch_size=self._batch_size,
288
- shuffle=False,
289
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
290
- pin_memory=(self.device.type == "cuda"),
291
- collate_fn=self.collate_fn)
292
- dataset_for_artifacts = data
293
- else: # data is None, use the trainer's default test dataset
294
- if self.validation_dataset is None:
295
- _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
296
- raise ValueError()
297
- # Create a fresh DataLoader from the test_dataset
298
- eval_loader = DataLoader(
299
- self.validation_dataset,
300
- batch_size=self._batch_size,
301
- shuffle=False,
302
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
303
- pin_memory=(self.device.type == "cuda"),
304
- collate_fn=self.collate_fn
305
- )
306
- dataset_for_artifacts = self.validation_dataset
307
-
308
- if eval_loader is None:
309
- _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
310
- raise ValueError()
311
-
312
- # print("\n--- Model Evaluation ---")
313
-
245
+ self._load_model_state_wrapper(model_checkpoint)
246
+
247
+ # Prepare Data using Base Helper
248
+ eval_loader, dataset_for_artifacts = self._prepare_eval_data(
249
+ data,
250
+ self.validation_dataset,
251
+ collate_fn=self.collate_fn # Important for Detection
252
+ )
253
+
254
+ # Gather all predictions and targets
314
255
  all_predictions = []
315
256
  all_targets = []
316
257
 
@@ -380,12 +321,8 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
380
321
  _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
381
322
  raise TypeError()
382
323
 
383
- # handle save path
384
- dir_path = make_fullpath(save_dir, make=True, enforce="directory")
385
- full_path = dir_path / finalize_config.filename
386
-
387
324
  # handle checkpoint
388
- self._load_model_state_for_finalizing(model_checkpoint)
325
+ self._load_model_state_wrapper(model_checkpoint)
389
326
 
390
327
  # Create finalized data
391
328
  finalized_data = {
@@ -397,6 +334,9 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
397
334
  if finalize_config.class_map is not None:
398
335
  finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
399
336
 
400
- torch.save(finalized_data, full_path)
401
-
402
- _LOGGER.info(f"Finalized model file saved to '{full_path}'")
337
+ # Save using base helper
338
+ self._save_finalized_artifact(
339
+ finalized_data=finalized_data,
340
+ save_dir=save_dir,
341
+ filename=finalize_config.filename
342
+ )
@@ -99,23 +99,11 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
99
99
  def _create_dataloaders(self, batch_size: int, shuffle: bool):
100
100
  """Initializes the DataLoaders."""
101
101
  # Ensure stability on MPS devices by setting num_workers to 0
102
- loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
103
-
104
- self.train_loader = DataLoader(
105
- dataset=self.train_dataset,
106
- batch_size=batch_size,
107
- shuffle=shuffle,
108
- num_workers=loader_workers,
109
- pin_memory=("cuda" in self.device.type),
110
- drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
111
- )
112
-
113
- self.validation_loader = DataLoader(
114
- dataset=self.validation_dataset,
115
- batch_size=batch_size,
116
- shuffle=False,
117
- num_workers=loader_workers,
118
- pin_memory=("cuda" in self.device.type)
102
+ self._make_dataloaders(
103
+ train_dataset=self.train_dataset,
104
+ validation_dataset=self.validation_dataset,
105
+ batch_size=batch_size,
106
+ shuffle=shuffle
119
107
  )
120
108
 
121
109
  def _train_step(self):
@@ -279,14 +267,9 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
279
267
  val_format_configuration: Optional configuration for validation metrics.
280
268
  test_format_configuration: Optional configuration for test metrics.
281
269
  """
282
- # Validate model checkpoint
283
- if isinstance(model_checkpoint, Path):
284
- checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
285
- elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
286
- checkpoint_validated = model_checkpoint
287
- else:
288
- _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
289
- raise ValueError()
270
+ # Validate inputs using base helpers
271
+ checkpoint_validated = self._validate_checkpoint_arg(model_checkpoint)
272
+ save_path = self._validate_save_dir(save_dir)
290
273
 
291
274
  # Validate val configuration
292
275
  if val_format_configuration is not None:
@@ -294,9 +277,6 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
294
277
  _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
295
278
  raise ValueError()
296
279
 
297
- # Validate directory
298
- save_path = make_fullpath(save_dir, make=True, enforce="directory")
299
-
300
280
  # Validate test data and dispatch
301
281
  if test_data is not None:
302
282
  if not isinstance(test_data, (DataLoader, Dataset)):
@@ -308,9 +288,9 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
308
288
  test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
309
289
 
310
290
  # Dispatch validation set
311
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
291
+ _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
312
292
  self._evaluate(save_dir=validation_metrics_path,
313
- model_checkpoint=checkpoint_validated,
293
+ model_checkpoint=checkpoint_validated, # type: ignore
314
294
  data=None,
315
295
  format_configuration=val_format_configuration)
316
296
 
@@ -329,16 +309,16 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
329
309
  test_configuration_validated = test_format_configuration
330
310
 
331
311
  # Dispatch test set
332
- _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
312
+ _LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
333
313
  self._evaluate(save_dir=test_metrics_path,
334
314
  model_checkpoint="current",
335
315
  data=test_data_validated,
336
316
  format_configuration=test_configuration_validated)
337
317
  else:
338
318
  # Dispatch validation set
339
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
319
+ _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
340
320
  self._evaluate(save_dir=save_path,
341
- model_checkpoint=checkpoint_validated,
321
+ model_checkpoint=checkpoint_validated, # type: ignore
342
322
  data=None,
343
323
  format_configuration=val_format_configuration)
344
324
 
@@ -350,42 +330,13 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
350
330
  """
351
331
  Private evaluation helper.
352
332
  """
353
- eval_loader = None
354
-
355
333
  # load model checkpoint
356
- if isinstance(model_checkpoint, Path):
357
- self._load_checkpoint(path=model_checkpoint)
358
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
359
- path_to_latest = self._checkpoint_callback.best_checkpoint_path
360
- self._load_checkpoint(path_to_latest)
361
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
362
- _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
363
- raise ValueError()
334
+ self._load_model_state_wrapper(model_checkpoint)
364
335
 
365
- # Dataloader
366
- if isinstance(data, DataLoader):
367
- eval_loader = data
368
- elif isinstance(data, Dataset):
369
- # Create a new loader from the provided dataset
370
- eval_loader = DataLoader(data,
371
- batch_size=self._batch_size,
372
- shuffle=False,
373
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
374
- pin_memory=(self.device.type == "cuda"))
375
- else: # data is None, use the trainer's default validation dataset
376
- if self.validation_dataset is None:
377
- _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
378
- raise ValueError()
379
- eval_loader = DataLoader(self.validation_dataset,
380
- batch_size=self._batch_size,
381
- shuffle=False,
382
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
383
- pin_memory=(self.device.type == "cuda"))
384
-
385
- if eval_loader is None:
386
- _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
387
- raise ValueError()
336
+ # Prepare Data using Base Helper
337
+ eval_loader, _ = self._prepare_eval_data(data, self.validation_dataset)
388
338
 
339
+ # Gather Predictions
389
340
  all_preds, _, all_true = [], [], []
390
341
  for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
391
342
  if y_pred_b is not None: all_preds.append(y_pred_b)
@@ -514,13 +465,9 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
514
465
  elif self.kind == MLTaskKeys.SEQUENCE_VALUE and not isinstance(finalize_config, FinalizeSequenceValuePrediction):
515
466
  _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
516
467
  raise TypeError()
517
-
518
- # handle save path
519
- dir_path = make_fullpath(save_dir, make=True, enforce="directory")
520
- full_path = dir_path / finalize_config.filename
521
468
 
522
469
  # handle checkpoint
523
- self._load_model_state_for_finalizing(model_checkpoint)
470
+ self._load_model_state_wrapper(model_checkpoint)
524
471
 
525
472
  # Create finalized data
526
473
  finalized_data = {
@@ -534,7 +481,10 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
534
481
  if finalize_config.initial_sequence is not None:
535
482
  finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence
536
483
 
537
- torch.save(finalized_data, full_path)
538
-
539
- _LOGGER.info(f"Finalized model file saved to '{full_path}'")
484
+ # Save using base helper
485
+ self._save_finalized_artifact(
486
+ finalized_data=finalized_data,
487
+ save_dir=save_dir,
488
+ filename=finalize_config.filename
489
+ )
540
490
 
@@ -142,23 +142,11 @@ class DragonTrainer(_BaseDragonTrainer):
142
142
  def _create_dataloaders(self, batch_size: int, shuffle: bool):
143
143
  """Initializes the DataLoaders."""
144
144
  # Ensure stability on MPS devices by setting num_workers to 0
145
- loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
146
-
147
- self.train_loader = DataLoader(
148
- dataset=self.train_dataset,
149
- batch_size=batch_size,
150
- shuffle=shuffle,
151
- num_workers=loader_workers,
152
- pin_memory=("cuda" in self.device.type),
153
- drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
154
- )
155
-
156
- self.validation_loader = DataLoader(
157
- dataset=self.validation_dataset,
158
- batch_size=batch_size,
159
- shuffle=False,
160
- num_workers=loader_workers,
161
- pin_memory=("cuda" in self.device.type)
145
+ self._make_dataloaders(
146
+ train_dataset=self.train_dataset,
147
+ validation_dataset=self.validation_dataset,
148
+ batch_size=batch_size,
149
+ shuffle=shuffle
162
150
  )
163
151
 
164
152
  def _train_step(self):
@@ -403,14 +391,9 @@ class DragonTrainer(_BaseDragonTrainer):
403
391
  val_format_configuration (object): Optional configuration for metric format output for the validation set.
404
392
  test_format_configuration (object): Optional configuration for metric format output for the test set.
405
393
  """
406
- # Validate model checkpoint
407
- if isinstance(model_checkpoint, Path):
408
- checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
409
- elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
410
- checkpoint_validated = model_checkpoint
411
- else:
412
- _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
413
- raise ValueError()
394
+ # Validate inputs using base helpers
395
+ checkpoint_validated = self._validate_checkpoint_arg(model_checkpoint)
396
+ save_path = self._validate_save_dir(save_dir)
414
397
 
415
398
  # Validate classification threshold
416
399
  if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
@@ -445,9 +428,6 @@ class DragonTrainer(_BaseDragonTrainer):
445
428
  else: # config is None
446
429
  val_configuration_validated = None
447
430
 
448
- # Validate directory
449
- save_path = make_fullpath(save_dir, make=True, enforce="directory")
450
-
451
431
  # Validate test data and dispatch
452
432
  if test_data is not None:
453
433
  if not isinstance(test_data, (DataLoader, Dataset)):
@@ -461,7 +441,7 @@ class DragonTrainer(_BaseDragonTrainer):
461
441
  # Dispatch validation set
462
442
  _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
463
443
  self._evaluate(save_dir=validation_metrics_path,
464
- model_checkpoint=checkpoint_validated,
444
+ model_checkpoint=checkpoint_validated, # type: ignore
465
445
  classification_threshold=threshold_validated,
466
446
  data=None,
467
447
  format_configuration=val_configuration_validated)
@@ -499,9 +479,9 @@ class DragonTrainer(_BaseDragonTrainer):
499
479
  format_configuration=test_configuration_validated)
500
480
  else:
501
481
  # Dispatch validation set
502
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
482
+ _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
503
483
  self._evaluate(save_dir=save_path,
504
- model_checkpoint=checkpoint_validated,
484
+ model_checkpoint=checkpoint_validated, # type: ignore
505
485
  classification_threshold=threshold_validated,
506
486
  data=None,
507
487
  format_configuration=val_configuration_validated)
@@ -525,55 +505,16 @@ class DragonTrainer(_BaseDragonTrainer):
525
505
  """
526
506
  Changed to a private helper function.
527
507
  """
528
- dataset_for_artifacts = None
529
- eval_loader = None
530
-
531
508
  # set threshold
532
509
  self._classification_threshold = classification_threshold
533
510
 
534
511
  # load model checkpoint
535
- if isinstance(model_checkpoint, Path):
536
- self._load_checkpoint(path=model_checkpoint)
537
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
538
- path_to_latest = self._checkpoint_callback.best_checkpoint_path
539
- self._load_checkpoint(path_to_latest)
540
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
541
- _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
542
- raise ValueError()
512
+ self._load_model_state_wrapper(model_checkpoint)
543
513
 
544
- # Dataloader
545
- if isinstance(data, DataLoader):
546
- eval_loader = data
547
- # Try to get the dataset from the loader for fetching target names
548
- if hasattr(data, 'dataset'):
549
- dataset_for_artifacts = data.dataset # type: ignore
550
- elif isinstance(data, Dataset):
551
- # Create a new loader from the provided dataset
552
- eval_loader = DataLoader(data,
553
- batch_size=self._batch_size,
554
- shuffle=False,
555
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
556
- pin_memory=(self.device.type == "cuda"))
557
- dataset_for_artifacts = data
558
- else: # data is None, use the trainer's default test dataset
559
- if self.validation_dataset is None:
560
- _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
561
- raise ValueError()
562
- # Create a fresh DataLoader from the test_dataset
563
- eval_loader = DataLoader(self.validation_dataset,
564
- batch_size=self._batch_size,
565
- shuffle=False,
566
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
567
- pin_memory=(self.device.type == "cuda"))
568
-
569
- dataset_for_artifacts = self.validation_dataset
570
-
571
- if eval_loader is None:
572
- _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
573
- raise ValueError()
574
-
575
- # print("\n--- Model Evaluation ---")
576
-
514
+ # Prepare Data using Base Helper
515
+ eval_loader, dataset_for_artifacts = self._prepare_eval_data(data, self.validation_dataset)
516
+
517
+ # Gather Predictions
577
518
  all_preds, all_probs, all_true = [], [], []
578
519
  for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
579
520
  if y_pred_b is not None: all_preds.append(y_pred_b)
@@ -1128,13 +1069,9 @@ class DragonTrainer(_BaseDragonTrainer):
1128
1069
  elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiLabelBinaryClassification):
1129
1070
  _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiLabelBinaryClassification', but got {type(finalize_config).__name__}.")
1130
1071
  raise TypeError()
1131
-
1132
- # handle save path
1133
- dir_path = make_fullpath(save_dir, make=True, enforce="directory")
1134
- full_path = dir_path / finalize_config.filename
1135
-
1072
+
1136
1073
  # handle checkpoint
1137
- self._load_model_state_for_finalizing(model_checkpoint)
1074
+ self._load_model_state_wrapper(model_checkpoint)
1138
1075
 
1139
1076
  # Create finalized data
1140
1077
  finalized_data = {
@@ -1153,8 +1090,10 @@ class DragonTrainer(_BaseDragonTrainer):
1153
1090
  if finalize_config.class_map is not None:
1154
1091
  finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
1155
1092
 
1156
- # Save model file
1157
- torch.save(finalized_data, full_path)
1158
-
1159
- _LOGGER.info(f"Finalized model file saved to '{full_path}'")
1093
+ # Save model file using base helper
1094
+ self._save_finalized_artifact(
1095
+ finalized_data=finalized_data,
1096
+ save_dir=save_dir,
1097
+ filename=finalize_config.filename
1098
+ )
1160
1099
 
@@ -16,7 +16,7 @@ from ._train_tools import (
16
16
  save_pretrained_transforms,
17
17
  )
18
18
 
19
- from ._imprimir import info
19
+ from .._core import _imprimir_disponibles
20
20
 
21
21
 
22
22
  __all__ = [
@@ -30,3 +30,7 @@ __all__ = [
30
30
  "save_pretrained_transforms",
31
31
  "select_features_by_shap"
32
32
  ]
33
+
34
+
35
+ def info():
36
+ _imprimir_disponibles(__all__)