dragon-ml-toolbox 20.2.0__py3-none-any.whl → 20.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/METADATA +1 -1
- dragon_ml_toolbox-20.3.0.dist-info/RECORD +143 -0
- ml_tools/ETL_cleaning/__init__.py +5 -1
- ml_tools/ETL_cleaning/_basic_clean.py +1 -1
- ml_tools/ETL_engineering/__init__.py +5 -1
- ml_tools/GUI_tools/__init__.py +5 -1
- ml_tools/IO_tools/_IO_loggers.py +12 -4
- ml_tools/IO_tools/__init__.py +5 -1
- ml_tools/MICE/__init__.py +8 -2
- ml_tools/MICE/_dragon_mice.py +1 -1
- ml_tools/ML_callbacks/__init__.py +5 -1
- ml_tools/ML_chain/__init__.py +5 -1
- ml_tools/ML_configuration/__init__.py +7 -1
- ml_tools/ML_configuration/_training.py +65 -1
- ml_tools/ML_datasetmaster/__init__.py +5 -1
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +31 -20
- ml_tools/ML_datasetmaster/_datasetmaster.py +26 -9
- ml_tools/ML_datasetmaster/_sequence_datasetmaster.py +38 -23
- ml_tools/ML_evaluation/__init__.py +5 -1
- ml_tools/ML_evaluation_captum/__init__.py +5 -1
- ml_tools/ML_finalize_handler/__init__.py +5 -1
- ml_tools/ML_inference/__init__.py +5 -1
- ml_tools/ML_inference_sequence/__init__.py +5 -1
- ml_tools/ML_inference_vision/__init__.py +5 -1
- ml_tools/ML_models/__init__.py +21 -6
- ml_tools/ML_models/_dragon_autoint.py +302 -0
- ml_tools/ML_models/_dragon_gate.py +358 -0
- ml_tools/ML_models/_dragon_node.py +268 -0
- ml_tools/ML_models/_dragon_tabnet.py +255 -0
- ml_tools/ML_models_sequence/__init__.py +5 -1
- ml_tools/ML_models_vision/__init__.py +5 -1
- ml_tools/ML_optimization/__init__.py +11 -3
- ml_tools/ML_optimization/_multi_dragon.py +2 -2
- ml_tools/ML_optimization/_single_dragon.py +47 -67
- ml_tools/ML_optimization/_single_manual.py +1 -1
- ml_tools/ML_scaler/_ML_scaler.py +12 -7
- ml_tools/ML_scaler/__init__.py +5 -1
- ml_tools/ML_trainer/__init__.py +5 -1
- ml_tools/ML_trainer/_base_trainer.py +136 -13
- ml_tools/ML_trainer/_dragon_detection_trainer.py +31 -91
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +24 -74
- ml_tools/ML_trainer/_dragon_trainer.py +24 -85
- ml_tools/ML_utilities/__init__.py +5 -1
- ml_tools/ML_utilities/_inspection.py +44 -30
- ml_tools/ML_vision_transformers/__init__.py +8 -2
- ml_tools/PSO_optimization/__init__.py +5 -1
- ml_tools/SQL/__init__.py +8 -2
- ml_tools/VIF/__init__.py +5 -1
- ml_tools/data_exploration/__init__.py +4 -1
- ml_tools/data_exploration/_cleaning.py +4 -2
- ml_tools/ensemble_evaluation/__init__.py +5 -1
- ml_tools/ensemble_inference/__init__.py +5 -1
- ml_tools/ensemble_learning/__init__.py +5 -1
- ml_tools/excel_handler/__init__.py +5 -1
- ml_tools/keys/__init__.py +5 -1
- ml_tools/math_utilities/__init__.py +5 -1
- ml_tools/optimization_tools/__init__.py +5 -1
- ml_tools/path_manager/__init__.py +8 -2
- ml_tools/plot_fonts/__init__.py +8 -2
- ml_tools/schema/__init__.py +8 -2
- ml_tools/schema/_feature_schema.py +3 -3
- ml_tools/serde/__init__.py +5 -1
- ml_tools/utilities/__init__.py +5 -1
- ml_tools/utilities/_utility_save_load.py +38 -20
- dragon_ml_toolbox-20.2.0.dist-info/RECORD +0 -179
- ml_tools/ETL_cleaning/_imprimir.py +0 -13
- ml_tools/ETL_engineering/_imprimir.py +0 -24
- ml_tools/GUI_tools/_imprimir.py +0 -12
- ml_tools/IO_tools/_imprimir.py +0 -14
- ml_tools/MICE/_imprimir.py +0 -11
- ml_tools/ML_callbacks/_imprimir.py +0 -12
- ml_tools/ML_chain/_imprimir.py +0 -12
- ml_tools/ML_configuration/_imprimir.py +0 -47
- ml_tools/ML_datasetmaster/_imprimir.py +0 -15
- ml_tools/ML_evaluation/_imprimir.py +0 -25
- ml_tools/ML_evaluation_captum/_imprimir.py +0 -10
- ml_tools/ML_finalize_handler/_imprimir.py +0 -8
- ml_tools/ML_inference/_imprimir.py +0 -11
- ml_tools/ML_inference_sequence/_imprimir.py +0 -8
- ml_tools/ML_inference_vision/_imprimir.py +0 -8
- ml_tools/ML_models/_advanced_models.py +0 -1086
- ml_tools/ML_models/_imprimir.py +0 -18
- ml_tools/ML_models_sequence/_imprimir.py +0 -8
- ml_tools/ML_models_vision/_imprimir.py +0 -16
- ml_tools/ML_optimization/_imprimir.py +0 -13
- ml_tools/ML_scaler/_imprimir.py +0 -8
- ml_tools/ML_trainer/_imprimir.py +0 -10
- ml_tools/ML_utilities/_imprimir.py +0 -16
- ml_tools/ML_vision_transformers/_imprimir.py +0 -14
- ml_tools/PSO_optimization/_imprimir.py +0 -10
- ml_tools/SQL/_imprimir.py +0 -8
- ml_tools/VIF/_imprimir.py +0 -10
- ml_tools/data_exploration/_imprimir.py +0 -32
- ml_tools/ensemble_evaluation/_imprimir.py +0 -14
- ml_tools/ensemble_inference/_imprimir.py +0 -9
- ml_tools/ensemble_learning/_imprimir.py +0 -10
- ml_tools/excel_handler/_imprimir.py +0 -13
- ml_tools/keys/_imprimir.py +0 -11
- ml_tools/math_utilities/_imprimir.py +0 -11
- ml_tools/optimization_tools/_imprimir.py +0 -13
- ml_tools/path_manager/_imprimir.py +0 -15
- ml_tools/plot_fonts/_imprimir.py +0 -8
- ml_tools/schema/_imprimir.py +0 -10
- ml_tools/serde/_imprimir.py +0 -10
- ml_tools/utilities/_imprimir.py +0 -18
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.3.0.dist-info}/top_level.txt +0 -0
|
@@ -80,26 +80,12 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
80
80
|
|
|
81
81
|
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
82
82
|
"""Initializes the DataLoaders with the object detection collate_fn."""
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
shuffle=shuffle,
|
|
90
|
-
num_workers=loader_workers,
|
|
91
|
-
pin_memory=("cuda" in self.device.type),
|
|
92
|
-
collate_fn=self.collate_fn, # Use the provided collate function
|
|
93
|
-
drop_last=True
|
|
94
|
-
)
|
|
95
|
-
|
|
96
|
-
self.validation_loader = DataLoader(
|
|
97
|
-
dataset=self.validation_dataset,
|
|
98
|
-
batch_size=batch_size,
|
|
99
|
-
shuffle=False,
|
|
100
|
-
num_workers=loader_workers,
|
|
101
|
-
pin_memory=("cuda" in self.device.type),
|
|
102
|
-
collate_fn=self.collate_fn # Use the provided collate function
|
|
83
|
+
self._make_dataloaders(
|
|
84
|
+
train_dataset=self.train_dataset,
|
|
85
|
+
validation_dataset=self.validation_dataset,
|
|
86
|
+
batch_size=batch_size,
|
|
87
|
+
shuffle=shuffle,
|
|
88
|
+
collate_fn=self.collate_fn
|
|
103
89
|
)
|
|
104
90
|
|
|
105
91
|
def _train_step(self):
|
|
@@ -207,17 +193,9 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
207
193
|
- If 'current', use the current state of the trained model up the latest trained epoch.
|
|
208
194
|
test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
|
|
209
195
|
"""
|
|
210
|
-
# Validate
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
|
|
214
|
-
checkpoint_validated = model_checkpoint
|
|
215
|
-
else:
|
|
216
|
-
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
|
|
217
|
-
raise ValueError()
|
|
218
|
-
|
|
219
|
-
# Validate directory
|
|
220
|
-
save_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
196
|
+
# Validate inputs using base helpers
|
|
197
|
+
checkpoint_validated = self._validate_checkpoint_arg(model_checkpoint)
|
|
198
|
+
save_path = self._validate_save_dir(save_dir)
|
|
221
199
|
|
|
222
200
|
# Validate test data and dispatch
|
|
223
201
|
if test_data is not None:
|
|
@@ -230,21 +208,21 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
230
208
|
test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
|
|
231
209
|
|
|
232
210
|
# Dispatch validation set
|
|
233
|
-
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
211
|
+
_LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
234
212
|
self._evaluate(save_dir=validation_metrics_path,
|
|
235
|
-
model_checkpoint=checkpoint_validated,
|
|
213
|
+
model_checkpoint=checkpoint_validated, # type: ignore
|
|
236
214
|
data=None) # 'None' triggers use of self.test_dataset
|
|
237
215
|
|
|
238
216
|
# Dispatch test set
|
|
239
|
-
_LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
217
|
+
_LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
240
218
|
self._evaluate(save_dir=test_metrics_path,
|
|
241
219
|
model_checkpoint="current", # Use 'current' state after loading checkpoint once
|
|
242
220
|
data=test_data_validated)
|
|
243
221
|
else:
|
|
244
222
|
# Dispatch validation set
|
|
245
|
-
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
223
|
+
_LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
246
224
|
self._evaluate(save_dir=save_path,
|
|
247
|
-
model_checkpoint=checkpoint_validated,
|
|
225
|
+
model_checkpoint=checkpoint_validated, # type: ignore
|
|
248
226
|
data=None) # 'None' triggers use of self.test_dataset
|
|
249
227
|
|
|
250
228
|
def _evaluate(self,
|
|
@@ -263,54 +241,17 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
263
241
|
- If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
|
|
264
242
|
- If 'current', use the current state of the trained model up the latest trained epoch.
|
|
265
243
|
"""
|
|
266
|
-
dataset_for_artifacts = None
|
|
267
|
-
eval_loader = None
|
|
268
|
-
|
|
269
244
|
# load model checkpoint
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
#
|
|
280
|
-
if isinstance(data, DataLoader):
|
|
281
|
-
eval_loader = data
|
|
282
|
-
if hasattr(data, 'dataset'):
|
|
283
|
-
dataset_for_artifacts = data.dataset # type: ignore
|
|
284
|
-
elif isinstance(data, Dataset):
|
|
285
|
-
# Create a new loader from the provided dataset
|
|
286
|
-
eval_loader = DataLoader(data,
|
|
287
|
-
batch_size=self._batch_size,
|
|
288
|
-
shuffle=False,
|
|
289
|
-
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
290
|
-
pin_memory=(self.device.type == "cuda"),
|
|
291
|
-
collate_fn=self.collate_fn)
|
|
292
|
-
dataset_for_artifacts = data
|
|
293
|
-
else: # data is None, use the trainer's default test dataset
|
|
294
|
-
if self.validation_dataset is None:
|
|
295
|
-
_LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
|
|
296
|
-
raise ValueError()
|
|
297
|
-
# Create a fresh DataLoader from the test_dataset
|
|
298
|
-
eval_loader = DataLoader(
|
|
299
|
-
self.validation_dataset,
|
|
300
|
-
batch_size=self._batch_size,
|
|
301
|
-
shuffle=False,
|
|
302
|
-
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
303
|
-
pin_memory=(self.device.type == "cuda"),
|
|
304
|
-
collate_fn=self.collate_fn
|
|
305
|
-
)
|
|
306
|
-
dataset_for_artifacts = self.validation_dataset
|
|
307
|
-
|
|
308
|
-
if eval_loader is None:
|
|
309
|
-
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
310
|
-
raise ValueError()
|
|
311
|
-
|
|
312
|
-
# print("\n--- Model Evaluation ---")
|
|
313
|
-
|
|
245
|
+
self._load_model_state_wrapper(model_checkpoint)
|
|
246
|
+
|
|
247
|
+
# Prepare Data using Base Helper
|
|
248
|
+
eval_loader, dataset_for_artifacts = self._prepare_eval_data(
|
|
249
|
+
data,
|
|
250
|
+
self.validation_dataset,
|
|
251
|
+
collate_fn=self.collate_fn # Important for Detection
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
# Gather all predictions and targets
|
|
314
255
|
all_predictions = []
|
|
315
256
|
all_targets = []
|
|
316
257
|
|
|
@@ -380,12 +321,8 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
380
321
|
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
|
|
381
322
|
raise TypeError()
|
|
382
323
|
|
|
383
|
-
# handle save path
|
|
384
|
-
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
385
|
-
full_path = dir_path / finalize_config.filename
|
|
386
|
-
|
|
387
324
|
# handle checkpoint
|
|
388
|
-
self.
|
|
325
|
+
self._load_model_state_wrapper(model_checkpoint)
|
|
389
326
|
|
|
390
327
|
# Create finalized data
|
|
391
328
|
finalized_data = {
|
|
@@ -397,6 +334,9 @@ class DragonDetectionTrainer(_BaseDragonTrainer):
|
|
|
397
334
|
if finalize_config.class_map is not None:
|
|
398
335
|
finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
|
|
399
336
|
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
337
|
+
# Save using base helper
|
|
338
|
+
self._save_finalized_artifact(
|
|
339
|
+
finalized_data=finalized_data,
|
|
340
|
+
save_dir=save_dir,
|
|
341
|
+
filename=finalize_config.filename
|
|
342
|
+
)
|
|
@@ -99,23 +99,11 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
99
99
|
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
100
100
|
"""Initializes the DataLoaders."""
|
|
101
101
|
# Ensure stability on MPS devices by setting num_workers to 0
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
shuffle=shuffle,
|
|
108
|
-
num_workers=loader_workers,
|
|
109
|
-
pin_memory=("cuda" in self.device.type),
|
|
110
|
-
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
111
|
-
)
|
|
112
|
-
|
|
113
|
-
self.validation_loader = DataLoader(
|
|
114
|
-
dataset=self.validation_dataset,
|
|
115
|
-
batch_size=batch_size,
|
|
116
|
-
shuffle=False,
|
|
117
|
-
num_workers=loader_workers,
|
|
118
|
-
pin_memory=("cuda" in self.device.type)
|
|
102
|
+
self._make_dataloaders(
|
|
103
|
+
train_dataset=self.train_dataset,
|
|
104
|
+
validation_dataset=self.validation_dataset,
|
|
105
|
+
batch_size=batch_size,
|
|
106
|
+
shuffle=shuffle
|
|
119
107
|
)
|
|
120
108
|
|
|
121
109
|
def _train_step(self):
|
|
@@ -279,14 +267,9 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
279
267
|
val_format_configuration: Optional configuration for validation metrics.
|
|
280
268
|
test_format_configuration: Optional configuration for test metrics.
|
|
281
269
|
"""
|
|
282
|
-
# Validate
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
|
|
286
|
-
checkpoint_validated = model_checkpoint
|
|
287
|
-
else:
|
|
288
|
-
_LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
|
|
289
|
-
raise ValueError()
|
|
270
|
+
# Validate inputs using base helpers
|
|
271
|
+
checkpoint_validated = self._validate_checkpoint_arg(model_checkpoint)
|
|
272
|
+
save_path = self._validate_save_dir(save_dir)
|
|
290
273
|
|
|
291
274
|
# Validate val configuration
|
|
292
275
|
if val_format_configuration is not None:
|
|
@@ -294,9 +277,6 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
294
277
|
_LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
|
|
295
278
|
raise ValueError()
|
|
296
279
|
|
|
297
|
-
# Validate directory
|
|
298
|
-
save_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
299
|
-
|
|
300
280
|
# Validate test data and dispatch
|
|
301
281
|
if test_data is not None:
|
|
302
282
|
if not isinstance(test_data, (DataLoader, Dataset)):
|
|
@@ -308,9 +288,9 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
308
288
|
test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
|
|
309
289
|
|
|
310
290
|
# Dispatch validation set
|
|
311
|
-
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
291
|
+
_LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
312
292
|
self._evaluate(save_dir=validation_metrics_path,
|
|
313
|
-
model_checkpoint=checkpoint_validated,
|
|
293
|
+
model_checkpoint=checkpoint_validated, # type: ignore
|
|
314
294
|
data=None,
|
|
315
295
|
format_configuration=val_format_configuration)
|
|
316
296
|
|
|
@@ -329,16 +309,16 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
329
309
|
test_configuration_validated = test_format_configuration
|
|
330
310
|
|
|
331
311
|
# Dispatch test set
|
|
332
|
-
_LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
312
|
+
_LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
333
313
|
self._evaluate(save_dir=test_metrics_path,
|
|
334
314
|
model_checkpoint="current",
|
|
335
315
|
data=test_data_validated,
|
|
336
316
|
format_configuration=test_configuration_validated)
|
|
337
317
|
else:
|
|
338
318
|
# Dispatch validation set
|
|
339
|
-
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
319
|
+
_LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
340
320
|
self._evaluate(save_dir=save_path,
|
|
341
|
-
model_checkpoint=checkpoint_validated,
|
|
321
|
+
model_checkpoint=checkpoint_validated, # type: ignore
|
|
342
322
|
data=None,
|
|
343
323
|
format_configuration=val_format_configuration)
|
|
344
324
|
|
|
@@ -350,42 +330,13 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
350
330
|
"""
|
|
351
331
|
Private evaluation helper.
|
|
352
332
|
"""
|
|
353
|
-
eval_loader = None
|
|
354
|
-
|
|
355
333
|
# load model checkpoint
|
|
356
|
-
|
|
357
|
-
self._load_checkpoint(path=model_checkpoint)
|
|
358
|
-
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
|
|
359
|
-
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
360
|
-
self._load_checkpoint(path_to_latest)
|
|
361
|
-
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
|
|
362
|
-
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
|
|
363
|
-
raise ValueError()
|
|
334
|
+
self._load_model_state_wrapper(model_checkpoint)
|
|
364
335
|
|
|
365
|
-
#
|
|
366
|
-
|
|
367
|
-
eval_loader = data
|
|
368
|
-
elif isinstance(data, Dataset):
|
|
369
|
-
# Create a new loader from the provided dataset
|
|
370
|
-
eval_loader = DataLoader(data,
|
|
371
|
-
batch_size=self._batch_size,
|
|
372
|
-
shuffle=False,
|
|
373
|
-
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
374
|
-
pin_memory=(self.device.type == "cuda"))
|
|
375
|
-
else: # data is None, use the trainer's default validation dataset
|
|
376
|
-
if self.validation_dataset is None:
|
|
377
|
-
_LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
|
|
378
|
-
raise ValueError()
|
|
379
|
-
eval_loader = DataLoader(self.validation_dataset,
|
|
380
|
-
batch_size=self._batch_size,
|
|
381
|
-
shuffle=False,
|
|
382
|
-
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
383
|
-
pin_memory=(self.device.type == "cuda"))
|
|
384
|
-
|
|
385
|
-
if eval_loader is None:
|
|
386
|
-
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
387
|
-
raise ValueError()
|
|
336
|
+
# Prepare Data using Base Helper
|
|
337
|
+
eval_loader, _ = self._prepare_eval_data(data, self.validation_dataset)
|
|
388
338
|
|
|
339
|
+
# Gather Predictions
|
|
389
340
|
all_preds, _, all_true = [], [], []
|
|
390
341
|
for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
|
|
391
342
|
if y_pred_b is not None: all_preds.append(y_pred_b)
|
|
@@ -514,13 +465,9 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
514
465
|
elif self.kind == MLTaskKeys.SEQUENCE_VALUE and not isinstance(finalize_config, FinalizeSequenceValuePrediction):
|
|
515
466
|
_LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
|
|
516
467
|
raise TypeError()
|
|
517
|
-
|
|
518
|
-
# handle save path
|
|
519
|
-
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
520
|
-
full_path = dir_path / finalize_config.filename
|
|
521
468
|
|
|
522
469
|
# handle checkpoint
|
|
523
|
-
self.
|
|
470
|
+
self._load_model_state_wrapper(model_checkpoint)
|
|
524
471
|
|
|
525
472
|
# Create finalized data
|
|
526
473
|
finalized_data = {
|
|
@@ -534,7 +481,10 @@ class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
|
534
481
|
if finalize_config.initial_sequence is not None:
|
|
535
482
|
finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence
|
|
536
483
|
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
484
|
+
# Save using base helper
|
|
485
|
+
self._save_finalized_artifact(
|
|
486
|
+
finalized_data=finalized_data,
|
|
487
|
+
save_dir=save_dir,
|
|
488
|
+
filename=finalize_config.filename
|
|
489
|
+
)
|
|
540
490
|
|
|
@@ -142,23 +142,11 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
142
142
|
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
143
143
|
"""Initializes the DataLoaders."""
|
|
144
144
|
# Ensure stability on MPS devices by setting num_workers to 0
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
shuffle=shuffle,
|
|
151
|
-
num_workers=loader_workers,
|
|
152
|
-
pin_memory=("cuda" in self.device.type),
|
|
153
|
-
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
154
|
-
)
|
|
155
|
-
|
|
156
|
-
self.validation_loader = DataLoader(
|
|
157
|
-
dataset=self.validation_dataset,
|
|
158
|
-
batch_size=batch_size,
|
|
159
|
-
shuffle=False,
|
|
160
|
-
num_workers=loader_workers,
|
|
161
|
-
pin_memory=("cuda" in self.device.type)
|
|
145
|
+
self._make_dataloaders(
|
|
146
|
+
train_dataset=self.train_dataset,
|
|
147
|
+
validation_dataset=self.validation_dataset,
|
|
148
|
+
batch_size=batch_size,
|
|
149
|
+
shuffle=shuffle
|
|
162
150
|
)
|
|
163
151
|
|
|
164
152
|
def _train_step(self):
|
|
@@ -403,14 +391,9 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
403
391
|
val_format_configuration (object): Optional configuration for metric format output for the validation set.
|
|
404
392
|
test_format_configuration (object): Optional configuration for metric format output for the test set.
|
|
405
393
|
"""
|
|
406
|
-
# Validate
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
|
|
410
|
-
checkpoint_validated = model_checkpoint
|
|
411
|
-
else:
|
|
412
|
-
_LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
|
|
413
|
-
raise ValueError()
|
|
394
|
+
# Validate inputs using base helpers
|
|
395
|
+
checkpoint_validated = self._validate_checkpoint_arg(model_checkpoint)
|
|
396
|
+
save_path = self._validate_save_dir(save_dir)
|
|
414
397
|
|
|
415
398
|
# Validate classification threshold
|
|
416
399
|
if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
|
|
@@ -445,9 +428,6 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
445
428
|
else: # config is None
|
|
446
429
|
val_configuration_validated = None
|
|
447
430
|
|
|
448
|
-
# Validate directory
|
|
449
|
-
save_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
450
|
-
|
|
451
431
|
# Validate test data and dispatch
|
|
452
432
|
if test_data is not None:
|
|
453
433
|
if not isinstance(test_data, (DataLoader, Dataset)):
|
|
@@ -461,7 +441,7 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
461
441
|
# Dispatch validation set
|
|
462
442
|
_LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
|
|
463
443
|
self._evaluate(save_dir=validation_metrics_path,
|
|
464
|
-
model_checkpoint=checkpoint_validated,
|
|
444
|
+
model_checkpoint=checkpoint_validated, # type: ignore
|
|
465
445
|
classification_threshold=threshold_validated,
|
|
466
446
|
data=None,
|
|
467
447
|
format_configuration=val_configuration_validated)
|
|
@@ -499,9 +479,9 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
499
479
|
format_configuration=test_configuration_validated)
|
|
500
480
|
else:
|
|
501
481
|
# Dispatch validation set
|
|
502
|
-
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
482
|
+
_LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
503
483
|
self._evaluate(save_dir=save_path,
|
|
504
|
-
model_checkpoint=checkpoint_validated,
|
|
484
|
+
model_checkpoint=checkpoint_validated, # type: ignore
|
|
505
485
|
classification_threshold=threshold_validated,
|
|
506
486
|
data=None,
|
|
507
487
|
format_configuration=val_configuration_validated)
|
|
@@ -525,55 +505,16 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
525
505
|
"""
|
|
526
506
|
Changed to a private helper function.
|
|
527
507
|
"""
|
|
528
|
-
dataset_for_artifacts = None
|
|
529
|
-
eval_loader = None
|
|
530
|
-
|
|
531
508
|
# set threshold
|
|
532
509
|
self._classification_threshold = classification_threshold
|
|
533
510
|
|
|
534
511
|
# load model checkpoint
|
|
535
|
-
|
|
536
|
-
self._load_checkpoint(path=model_checkpoint)
|
|
537
|
-
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
|
|
538
|
-
path_to_latest = self._checkpoint_callback.best_checkpoint_path
|
|
539
|
-
self._load_checkpoint(path_to_latest)
|
|
540
|
-
elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
|
|
541
|
-
_LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
|
|
542
|
-
raise ValueError()
|
|
512
|
+
self._load_model_state_wrapper(model_checkpoint)
|
|
543
513
|
|
|
544
|
-
#
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
if hasattr(data, 'dataset'):
|
|
549
|
-
dataset_for_artifacts = data.dataset # type: ignore
|
|
550
|
-
elif isinstance(data, Dataset):
|
|
551
|
-
# Create a new loader from the provided dataset
|
|
552
|
-
eval_loader = DataLoader(data,
|
|
553
|
-
batch_size=self._batch_size,
|
|
554
|
-
shuffle=False,
|
|
555
|
-
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
556
|
-
pin_memory=(self.device.type == "cuda"))
|
|
557
|
-
dataset_for_artifacts = data
|
|
558
|
-
else: # data is None, use the trainer's default test dataset
|
|
559
|
-
if self.validation_dataset is None:
|
|
560
|
-
_LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
|
|
561
|
-
raise ValueError()
|
|
562
|
-
# Create a fresh DataLoader from the test_dataset
|
|
563
|
-
eval_loader = DataLoader(self.validation_dataset,
|
|
564
|
-
batch_size=self._batch_size,
|
|
565
|
-
shuffle=False,
|
|
566
|
-
num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
|
|
567
|
-
pin_memory=(self.device.type == "cuda"))
|
|
568
|
-
|
|
569
|
-
dataset_for_artifacts = self.validation_dataset
|
|
570
|
-
|
|
571
|
-
if eval_loader is None:
|
|
572
|
-
_LOGGER.error("Cannot evaluate. No valid data was provided or found.")
|
|
573
|
-
raise ValueError()
|
|
574
|
-
|
|
575
|
-
# print("\n--- Model Evaluation ---")
|
|
576
|
-
|
|
514
|
+
# Prepare Data using Base Helper
|
|
515
|
+
eval_loader, dataset_for_artifacts = self._prepare_eval_data(data, self.validation_dataset)
|
|
516
|
+
|
|
517
|
+
# Gather Predictions
|
|
577
518
|
all_preds, all_probs, all_true = [], [], []
|
|
578
519
|
for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
|
|
579
520
|
if y_pred_b is not None: all_preds.append(y_pred_b)
|
|
@@ -1128,13 +1069,9 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
1128
1069
|
elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION and not isinstance(finalize_config, FinalizeMultiLabelBinaryClassification):
|
|
1129
1070
|
_LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeMultiLabelBinaryClassification', but got {type(finalize_config).__name__}.")
|
|
1130
1071
|
raise TypeError()
|
|
1131
|
-
|
|
1132
|
-
# handle save path
|
|
1133
|
-
dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
1134
|
-
full_path = dir_path / finalize_config.filename
|
|
1135
|
-
|
|
1072
|
+
|
|
1136
1073
|
# handle checkpoint
|
|
1137
|
-
self.
|
|
1074
|
+
self._load_model_state_wrapper(model_checkpoint)
|
|
1138
1075
|
|
|
1139
1076
|
# Create finalized data
|
|
1140
1077
|
finalized_data = {
|
|
@@ -1153,8 +1090,10 @@ class DragonTrainer(_BaseDragonTrainer):
|
|
|
1153
1090
|
if finalize_config.class_map is not None:
|
|
1154
1091
|
finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
|
|
1155
1092
|
|
|
1156
|
-
# Save model file
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1093
|
+
# Save model file using base helper
|
|
1094
|
+
self._save_finalized_artifact(
|
|
1095
|
+
finalized_data=finalized_data,
|
|
1096
|
+
save_dir=save_dir,
|
|
1097
|
+
filename=finalize_config.filename
|
|
1098
|
+
)
|
|
1160
1099
|
|
|
@@ -16,7 +16,7 @@ from ._train_tools import (
|
|
|
16
16
|
save_pretrained_transforms,
|
|
17
17
|
)
|
|
18
18
|
|
|
19
|
-
from
|
|
19
|
+
from .._core import _imprimir_disponibles
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
__all__ = [
|
|
@@ -30,3 +30,7 @@ __all__ = [
|
|
|
30
30
|
"save_pretrained_transforms",
|
|
31
31
|
"select_features_by_shap"
|
|
32
32
|
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def info():
|
|
36
|
+
_imprimir_disponibles(__all__)
|