dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_utilities.py CHANGED
@@ -1,18 +1,26 @@
  import pandas as pd
  from pathlib import Path
- from typing import Union, Any, Optional
+ from typing import Union, Any, Optional, Dict, List, Iterable
+ import torch
+ from torch import nn
 
  from .path_manager import make_fullpath, list_subdirectories, list_files_by_extension
  from ._script_info import _script_info
  from ._logger import _LOGGER
- from .keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys
+ from ._keys import DatasetKeys, PytorchModelArchitectureKeys, PytorchArtifactPathKeys, SHAPKeys, UtilityKeys, PyTorchCheckpointKeys
  from .utilities import load_dataframe
- from .custom_logger import save_list_strings
+ from .custom_logger import save_list_strings, custom_logger
+ from .serde import serialize_object_filename
 
 
  __all__ = [
      "find_model_artifacts",
-     "select_features_by_shap"
+     "select_features_by_shap",
+     "get_model_parameters",
+     "inspect_model_architecture",
+     "inspect_pth_file",
+     "set_parameter_requires_grad",
+     "save_pretrained_transforms"
  ]
 
 
@@ -226,5 +234,344 @@ def select_features_by_shap(
      return final_features
 
 
+ def get_model_parameters(model: nn.Module, save_dir: Optional[Union[str, Path]] = None) -> Dict[str, int]:
+     """
+     Calculates the total and trainable parameters of a PyTorch model.
+ 
+     Args:
+         model (nn.Module): The PyTorch model to inspect.
+         save_dir: Optional directory to save the output as a JSON file.
+ 
+     Returns:
+         Dict[str, int]: A dictionary containing:
+             - "total_params": The total number of parameters.
+             - "trainable_params": The number of trainable parameters (where requires_grad=True).
+     """
+     total_params = sum(p.numel() for p in model.parameters())
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ 
+     report = {
+         UtilityKeys.TOTAL_PARAMS: total_params,
+         UtilityKeys.TRAINABLE_PARAMS: trainable_params
+     }
+ 
+     if save_dir is not None:
+         output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+         custom_logger(data=report,
+                       save_directory=output_dir,
+                       log_name=UtilityKeys.MODEL_PARAMS_FILE,
+                       add_timestamp=False,
+                       dict_as="json")
+ 
+     return report
+ 
+ 
+ def inspect_model_architecture(
+     model: nn.Module,
+     save_dir: Union[str, Path]
+ ) -> None:
+     """
+     Saves a human-readable text summary of a model's instantiated
+     architecture, including parameter counts.
+ 
+     Args:
+         model (nn.Module): The PyTorch model to inspect.
+         save_dir (str | Path): Directory to save the text file.
+     """
+     # --- 1. Validate path ---
+     output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+     architecture_filename = UtilityKeys.MODEL_ARCHITECTURE_FILE + ".txt"
+     filepath = output_dir / architecture_filename
+ 
+     # --- 2. Get parameter counts from existing function ---
+     try:
+         params_report = get_model_parameters(model)  # get the dict without saving it
+         total = params_report.get(UtilityKeys.TOTAL_PARAMS, 'N/A')
+         trainable = params_report.get(UtilityKeys.TRAINABLE_PARAMS, 'N/A')
+         # If the counts are missing, formatting 'N/A' with ',' raises and we
+         # fall through to the minimal header below.
+         header = (
+             f"Model: {model.__class__.__name__}\n"
+             f"Total Parameters: {total:,}\n"
+             f"Trainable Parameters: {trainable:,}\n"
+             f"{'='*80}\n\n"
+         )
+     except Exception as e:
+         _LOGGER.warning(f"Could not get model parameters: {e}")
+         header = f"Model: {model.__class__.__name__}\n{'='*80}\n\n"
+ 
+     # --- 3. Get architecture string ---
+     arch_string = str(model)
+ 
+     # --- 4. Write to file ---
+     try:
+         with open(filepath, 'w', encoding='utf-8') as f:
+             f.write(header)
+             f.write(arch_string)
+         _LOGGER.info(f"Model architecture summary saved to '{filepath.name}'")
+     except Exception as e:
+         _LOGGER.error(f"Failed to write model architecture file: {e}")
+         raise
+ 
+ 
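A companion sketch for the architecture summary (same toy model idea; the directory is a placeholder):

    import torch.nn as nn
    from ml_tools.ML_utilities import inspect_model_architecture

    toy_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
    # Writes a .txt file containing the parameter-count header followed by repr(model)
    inspect_model_architecture(toy_model, save_dir="reports")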
315
+ def inspect_pth_file(
+     pth_path: Union[str, Path],
+     save_dir: Union[str, Path],
+ ) -> None:
+     """
+     Inspects a .pth file (e.g., a checkpoint) and saves a human-readable
+     JSON summary of its contents to `save_dir`.
+ 
+     Args:
+         pth_path (str | Path): The path to the .pth file to inspect.
+         save_dir (str | Path): The directory to save the JSON report.
+ 
+     Raises:
+         ValueError: If the .pth file is empty or in an unrecognized format.
+     """
+     # --- 1. Validate paths ---
+     pth_file = make_fullpath(pth_path, enforce="file")
+     output_dir = make_fullpath(save_dir, make=True, enforce="directory")
+     pth_name = pth_file.stem
+ 
+     # --- 2. Load data ---
+     try:
+         # Load onto CPU to avoid GPU memory issues
+         loaded_data = torch.load(pth_file, map_location=torch.device('cpu'))
+     except Exception as e:
+         _LOGGER.error(f"Failed to load .pth file '{pth_file}': {e}")
+         raise
+ 
+     # --- 3. Initialize report ---
+     report = {
+         "top_level_type": str(type(loaded_data)),
+         "top_level_summary": {},
+         "model_state_analysis": None,
+         "notes": []
+     }
+ 
+     # --- 4. Parse loaded data ---
+     if isinstance(loaded_data, dict):
+         # --- Case 1: Loaded data is a dictionary (most common case) ---
+         # Main loop that iterates over *everything* first.
+         for key, value in loaded_data.items():
+             key_summary = {}
+             val_type = str(type(value))
+             key_summary["type"] = val_type
+ 
+             if isinstance(value, torch.Tensor):
+                 key_summary["shape"] = list(value.shape)
+                 key_summary["dtype"] = str(value.dtype)
+             elif isinstance(value, dict):
+                 key_summary["key_count"] = len(value)
+                 key_summary["key_preview"] = list(value.keys())[:5]
+             elif isinstance(value, (int, float, str, bool)):
+                 key_summary["value_preview"] = str(value)
+             elif isinstance(value, (list, tuple)):
+                 key_summary["value_preview"] = str(value)[:100]
+ 
+             report["top_level_summary"][key] = key_summary
+ 
+         # Now, try to find the model state_dict within the dict
+         if PyTorchCheckpointKeys.MODEL_STATE in loaded_data and isinstance(loaded_data[PyTorchCheckpointKeys.MODEL_STATE], dict):
+             report["notes"].append(f"Found standard checkpoint key: '{PyTorchCheckpointKeys.MODEL_STATE}'. Analyzing as model state_dict.")
+             state_dict = loaded_data[PyTorchCheckpointKeys.MODEL_STATE]
+             report["model_state_analysis"] = _generate_weight_report(state_dict)
+ 
+         elif all(isinstance(v, torch.Tensor) for v in loaded_data.values()):
+             report["notes"].append("File dictionary contains only tensors. Analyzing entire dictionary as model state_dict.")
+             state_dict = loaded_data
+             report["model_state_analysis"] = _generate_weight_report(state_dict)
+ 
+         else:
+             report["notes"].append("Could not identify a single model state_dict. See top_level_summary for all contents. No detailed weight analysis will be performed.")
+ 
+     elif isinstance(loaded_data, nn.Module):
+         # --- Case 2: Loaded data is a full pickled model ---
+         report["notes"].append("File is a full, pickled nn.Module. This is not recommended. Extracting state_dict() for analysis.")
+         state_dict = loaded_data.state_dict()
+         report["model_state_analysis"] = _generate_weight_report(state_dict)
+ 
+     else:
+         # --- Case 3: Unrecognized format (e.g., single tensor, list) ---
+         _LOGGER.error(f"Could not parse .pth file. Loaded data is of type {type(loaded_data)}, not a dict or nn.Module.")
+         raise ValueError()
+ 
+     # --- 5. Save report ---
+     custom_logger(data=report,
+                   save_directory=output_dir,
+                   log_name=UtilityKeys.PTH_FILE + pth_name,
+                   add_timestamp=False,
+                   dict_as="json")
+ 
+ 
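A hedged usage sketch for the checkpoint inspector (the checkpoint path is hypothetical):

    from ml_tools.ML_utilities import inspect_pth_file

    # Produces a JSON report listing the checkpoint's top-level keys and, when a
    # state_dict is identified, per-tensor shapes, dtypes, and element counts.
    inspect_pth_file("checkpoints/epoch_10.pth", save_dir="reports")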
410
+ def _generate_weight_report(state_dict: dict) -> dict:
+     """
+     Internal helper to analyze a state_dict and return a structured report.
+ 
+     Args:
+         state_dict (dict): The model state_dict to analyze.
+ 
+     Returns:
+         dict: A report containing total parameters and a per-parameter breakdown.
+     """
+     weight_report = {}
+     total_params = 0
+     if not isinstance(state_dict, dict):
+         _LOGGER.warning(f"Attempted to generate weight report on non-dict type: {type(state_dict)}")
+         return {"error": "Input was not a dictionary."}
+ 
+     for key, tensor in state_dict.items():
+         if not isinstance(tensor, torch.Tensor):
+             _LOGGER.warning(f"Skipping key '{key}' in state_dict: value is not a tensor (type: {type(tensor)}).")
+             weight_report[key] = {
+                 "type": str(type(tensor)),
+                 "value_preview": str(tensor)[:50]  # show a short preview
+             }
+             continue
+         weight_report[key] = {
+             "shape": list(tensor.shape),
+             "dtype": str(tensor.dtype),
+             "requires_grad": tensor.requires_grad,
+             "num_elements": tensor.numel()
+         }
+         total_params += tensor.numel()
+ 
+     return {
+         "total_parameters": total_params,
+         "parameter_key_count": len(weight_report),
+         "parameters": weight_report
+     }
+ 
+ 
449
+ def set_parameter_requires_grad(
+     model: nn.Module,
+     unfreeze_last_n_params: int,
+ ) -> int:
+     """
+     Freezes or unfreezes parameters in a model based on unfreeze_last_n_params.
+ 
+     - N = 0: Freezes ALL parameters.
+     - 0 < N < total: Freezes ALL parameters, then unfreezes the last N.
+     - N >= total: Unfreezes ALL parameters.
+ 
+     Note: 'N' refers to individual parameter tensors (e.g., `layer.weight`
+     or `layer.bias`), not modules or layers. For example, to unfreeze
+     the final nn.Linear layer, you would use N=2 (for its weight and bias).
+ 
+     Args:
+         model (nn.Module): The model to modify.
+         unfreeze_last_n_params (int):
+             The number of parameter tensors to unfreeze, starting from
+             the end of the model.
+ 
+     Returns:
+         int: The total number of individual parameters (elements) that were set to `requires_grad=True`.
+     """
+     if unfreeze_last_n_params < 0:
+         _LOGGER.error(f"unfreeze_last_n_params must be >= 0, but got {unfreeze_last_n_params}")
+         raise ValueError()
+ 
+     # --- Step 1: Get all parameter tensors ---
+     all_params = list(model.parameters())
+     total_param_tensors = len(all_params)
+ 
+     # --- Case 1: N = 0 (freeze ALL parameters) ---
+     # Early exit for the "freeze all" case.
+     if unfreeze_last_n_params == 0:
+         params_frozen = _set_params_grad(all_params, requires_grad=False)
+         _LOGGER.warning(f"Froze all {total_param_tensors} parameter tensors ({params_frozen} total elements).")
+         return 0  # 0 parameters unfrozen
+ 
+     # --- Case 2: N >= total (unfreeze ALL parameters) ---
+     if unfreeze_last_n_params >= total_param_tensors:
+         if unfreeze_last_n_params > total_param_tensors:
+             _LOGGER.warning(f"Requested to unfreeze {unfreeze_last_n_params} params, but model only has {total_param_tensors}. Unfreezing all.")
+ 
+         params_unfrozen = _set_params_grad(all_params, requires_grad=True)
+         _LOGGER.info(f"Unfroze all {total_param_tensors} parameter tensors ({params_unfrozen} total elements) for training.")
+         return params_unfrozen
+ 
+     # --- Case 3: 0 < N < total (standard: freeze all, then unfreeze last N) ---
+     # Freeze ALL
+     params_frozen = _set_params_grad(all_params, requires_grad=False)
+     _LOGGER.info(f"Froze {params_frozen} parameters.")
+ 
+     # Unfreeze the last N
+     params_to_unfreeze = all_params[-unfreeze_last_n_params:]
+ 
+     # These are all False at this point, so the helper will set them to True
+     params_unfrozen = _set_params_grad(params_to_unfreeze, requires_grad=True)
+ 
+     _LOGGER.info(f"Unfroze the last {unfreeze_last_n_params} parameter tensors ({params_unfrozen} total elements) for training.")
+ 
+     return params_unfrozen
+ 
+ 
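A minimal fine-tuning sketch, following the N=2 example from the docstring (the toy model is a placeholder):

    import torch.nn as nn
    from ml_tools.ML_utilities import set_parameter_requires_grad

    toy_model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
    # Freeze everything except the last Linear layer's weight and bias (2 parameter tensors)
    n_unfrozen = set_parameter_requires_grad(toy_model, unfreeze_last_n_params=2)
    print(n_unfrozen)  # 64*10 + 10 = 650 trainable elements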
513
+ def _set_params_grad(
+     params: Iterable[nn.Parameter],
+     requires_grad: bool
+ ) -> int:
+     """
+     A helper function to set the `requires_grad` attribute for an iterable
+     of parameters and return the total number of elements changed.
+     """
+     params_changed = 0
+     for param in params:
+         if param.requires_grad != requires_grad:
+             param.requires_grad = requires_grad
+             params_changed += param.numel()
+     return params_changed
+ 
+ 
529
+ def save_pretrained_transforms(model: nn.Module, output_dir: Union[str, Path]):
+     """
+     Checks a model for the '_pretrained_default_transforms' attribute and,
+     if found, serializes the transform object as a .joblib file.
+ 
+     This saves the callable transform object itself for later use, such as
+     passing it directly to the 'transform_source' argument of the
+     PyTorchVisionInferenceHandler.
+ 
+     Args:
+         model (nn.Module): The model instance to check.
+         output_dir (str | Path): The directory where the transform file will be saved.
+     """
+     output_filename = "pretrained_model_transformations"
+ 
+     # 1. Check for the "secret" attribute
+     if not hasattr(model, '_pretrained_default_transforms'):
+         _LOGGER.warning(f"Model of type {type(model).__name__} does not have the required attribute. No transformations saved.")
+         return
+ 
+     # 2. Get the transform object
+     try:
+         transform_obj = model._pretrained_default_transforms
+     except Exception as e:
+         _LOGGER.error(f"Error accessing the required attribute on model: {e}")
+         return
+ 
+     # 3. Check if the object is actually there
+     if transform_obj is None:
+         _LOGGER.warning(f"Model {type(model).__name__} has the required attribute but it is None. No transforms saved.")
+         return
+ 
+     # 4. Serialize and save using serde
+     try:
+         serialize_object_filename(
+             obj=transform_obj,
+             save_dir=output_dir,
+             filename=output_filename,
+             verbose=True,
+             raise_on_error=True
+         )
+     except Exception as e:
+         _LOGGER.error(f"Failed to serialize transformations: {e}")
+         raise
+ 
+ 
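A self-contained sketch of the attribute contract this function expects (the toy class is illustrative only, not part of the package):

    import torch.nn as nn
    from ml_tools.ML_utilities import save_pretrained_transforms

    class ToyVisionWrapper(nn.Module):
        """Stand-in for a pretrained wrapper that exposes default transforms."""
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(8, 2)
            # Placeholder transform object; a real wrapper would expose the
            # preprocessing pipeline matching its pretrained weights.
            self._pretrained_default_transforms = nn.Identity()

    # Serializes the transform object to 'pretrained_model_transformations.joblib'
    save_pretrained_transforms(ToyVisionWrapper(), output_dir="artifacts")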
  def info():
      _script_info(__all__)