dragon-ml-toolbox 13.7.0__tar.gz → 14.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-13.7.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-14.0.0}/PKG-INFO +2 -1
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0/dragon_ml_toolbox.egg-info}/PKG-INFO +2 -1
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +7 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/dragon_ml_toolbox.egg-info/requires.txt +1 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_datasetmaster.py +2 -185
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_evaluation.py +3 -3
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_inference.py +0 -1
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_models.py +3 -1
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_trainer.py +446 -11
- dragon_ml_toolbox-14.0.0/ml_tools/ML_utilities.py +528 -0
- dragon_ml_toolbox-14.0.0/ml_tools/ML_vision_datasetmaster.py +1315 -0
- dragon_ml_toolbox-14.0.0/ml_tools/ML_vision_evaluation.py +260 -0
- dragon_ml_toolbox-14.0.0/ml_tools/ML_vision_inference.py +428 -0
- dragon_ml_toolbox-14.0.0/ml_tools/ML_vision_models.py +627 -0
- dragon_ml_toolbox-14.0.0/ml_tools/ML_vision_transformers.py +58 -0
- dragon_ml_toolbox-14.0.0/ml_tools/_ML_pytorch_tabular.py +543 -0
- dragon_ml_toolbox-14.0.0/ml_tools/_ML_vision_recipe.py +88 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/custom_logger.py +37 -14
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/keys.py +38 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/pyproject.toml +12 -2
- dragon_ml_toolbox-13.7.0/ml_tools/ML_utilities.py +0 -230
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/LICENSE +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/README.md +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ETL_cleaning.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ETL_engineering.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/GUI_tools.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/MICE_imputation.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_callbacks.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_evaluation_multi.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_optimization.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_scaler.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/PSO_optimization.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/SQL.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/VIF_factor.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/_logger.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/_schema.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/_script_info.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/constants.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/data_exploration.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ensemble_evaluation.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ensemble_inference.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ensemble_learning.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/math_utilities.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/optimization_tools.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/path_manager.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/serde.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/utilities.py +0 -0
- {dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/setup.cfg +0 -0
{dragon_ml_toolbox-13.7.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-14.0.0}/PKG-INFO
RENAMED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 13.7.0
+Version: 14.0.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
 License-Expression: MIT
@@ -34,6 +34,7 @@ Requires-Dist: Pillow; extra == "ml"
 Requires-Dist: evotorch; extra == "ml"
 Requires-Dist: pyarrow; extra == "ml"
 Requires-Dist: colorlog; extra == "ml"
+Requires-Dist: torchmetrics; extra == "ml"
 Provides-Extra: mice
 Requires-Dist: numpy<2.0; extra == "mice"
 Requires-Dist: pandas; extra == "mice"
```
{dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0/dragon_ml_toolbox.egg-info}/PKG-INFO
RENAMED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 13.7.0
+Version: 14.0.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
 License-Expression: MIT
@@ -34,6 +34,7 @@ Requires-Dist: Pillow; extra == "ml"
 Requires-Dist: evotorch; extra == "ml"
 Requires-Dist: pyarrow; extra == "ml"
 Requires-Dist: colorlog; extra == "ml"
+Requires-Dist: torchmetrics; extra == "ml"
 Provides-Extra: mice
 Requires-Dist: numpy<2.0; extra == "mice"
 Requires-Dist: pandas; extra == "mice"
```
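Both PKG-INFO diffs record the same two changes: the version bump and a new torchmetrics dependency under the `ml` extra, presumably supporting the new vision evaluation modules. As a brief, hedged illustration of the dependency itself (standard torchmetrics API, not code from this package):

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

# Stateful metric: each call updates internal counters and returns the value.
metric = MulticlassAccuracy(num_classes=3, average="micro")
preds = torch.tensor([0, 2, 1, 2])   # predicted class indices
target = torch.tensor([0, 1, 1, 2])  # ground-truth class indices
acc = metric(preds, target)
print(f"accuracy: {acc.item():.2f}")  # 3 of 4 correct -> 0.75
```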
{dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/dragon_ml_toolbox.egg-info/SOURCES.txt
RENAMED
```diff
@@ -21,10 +21,17 @@ ml_tools/ML_optimization.py
 ml_tools/ML_scaler.py
 ml_tools/ML_trainer.py
 ml_tools/ML_utilities.py
+ml_tools/ML_vision_datasetmaster.py
+ml_tools/ML_vision_evaluation.py
+ml_tools/ML_vision_inference.py
+ml_tools/ML_vision_models.py
+ml_tools/ML_vision_transformers.py
 ml_tools/PSO_optimization.py
 ml_tools/RNN_forecast.py
 ml_tools/SQL.py
 ml_tools/VIF_factor.py
+ml_tools/_ML_pytorch_tabular.py
+ml_tools/_ML_vision_recipe.py
 ml_tools/__init__.py
 ml_tools/_logger.py
 ml_tools/_schema.py
```
{dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_datasetmaster.py
RENAMED
```diff
@@ -1,13 +1,10 @@
 import torch
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset
 import pandas
 import numpy
 from sklearn.model_selection import train_test_split
 from typing import Literal, Union, Tuple, List, Optional
 from abc import ABC, abstractmethod
-from PIL import Image, ImageOps
-from torchvision.datasets import ImageFolder
-from torchvision import transforms
 import matplotlib.pyplot as plt
 from pathlib import Path
 
@@ -23,9 +20,7 @@ from ._schema import FeatureSchema
 __all__ = [
     "DatasetMaker",
     "DatasetMakerMulti",
-    "VisionDatasetMaker",
-    "SequenceMaker",
-    "ResizeAspectFill",
+    "SequenceMaker"
 ]
 
 
@@ -473,149 +468,6 @@ class _BaseMaker(ABC):
         pass
 
 
-# --- VisionDatasetMaker ---
-class VisionDatasetMaker(_BaseMaker):
-    """
-    Creates processed PyTorch datasets for computer vision tasks from an
-    image folder directory.
-
-    Uses online augmentations per epoch (image augmentation without creating new files).
-    """
-    def __init__(self, full_dataset: ImageFolder):
-        super().__init__()
-        self.full_dataset = full_dataset
-        self.labels = [s[1] for s in self.full_dataset.samples]
-        self.class_map = full_dataset.class_to_idx
-
-        self._is_split = False
-        self._are_transforms_configured = False
-
-    @classmethod
-    def from_folder(cls, root_dir: str) -> 'VisionDatasetMaker':
-        """Creates a maker instance from a root directory of images."""
-        initial_transform = transforms.Compose([transforms.ToTensor()])
-        full_dataset = ImageFolder(root=root_dir, transform=initial_transform)
-        _LOGGER.info(f"Found {len(full_dataset)} images in {len(full_dataset.classes)} classes.")
-        return cls(full_dataset)
-
-    @staticmethod
-    def inspect_folder(path: Union[str, Path]):
-        """
-        Logs a report of the types, sizes, and channels of image files
-        found in the directory and its subdirectories.
-        """
-        path_obj = make_fullpath(path)
-
-        non_image_files = set()
-        img_types = set()
-        img_sizes = set()
-        img_channels = set()
-        img_counter = 0
-
-        _LOGGER.info(f"Inspecting folder: {path_obj}...")
-        # Use rglob to recursively find all files
-        for filepath in path_obj.rglob('*'):
-            if filepath.is_file():
-                try:
-                    # Using PIL to open is a more reliable check
-                    with Image.open(filepath) as img:
-                        img_types.add(img.format)
-                        img_sizes.add(img.size)
-                        img_channels.update(img.getbands())
-                        img_counter += 1
-                except (IOError, SyntaxError):
-                    non_image_files.add(filepath.name)
-
-        if non_image_files:
-            _LOGGER.warning(f"Non-image or corrupted files found and ignored: {non_image_files}")
-
-        report = (
-            f"\n--- Inspection Report for '{path_obj.name}' ---\n"
-            f"Total images found: {img_counter}\n"
-            f"Image formats: {img_types or 'None'}\n"
-            f"Image sizes (WxH): {img_sizes or 'None'}\n"
-            f"Image channels (bands): {img_channels or 'None'}\n"
-            f"--------------------------------------"
-        )
-        print(report)
-
-    def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
-                   stratify: bool = True, random_state: Optional[int] = None) -> 'VisionDatasetMaker':
-        """Splits the dataset into training, validation, and optional test sets."""
-        if self._is_split:
-            _LOGGER.warning("Data has already been split.")
-            return self
-
-        if val_size + test_size >= 1.0:
-            _LOGGER.error("The sum of val_size and test_size must be less than 1.")
-            raise ValueError()
-
-        indices = list(range(len(self.full_dataset)))
-        labels_for_split = self.labels if stratify else None
-
-        train_indices, val_test_indices = train_test_split(
-            indices, test_size=(val_size + test_size), random_state=random_state, stratify=labels_for_split
-        )
-
-        if test_size > 0:
-            val_test_labels = [self.labels[i] for i in val_test_indices]
-            stratify_val_test = val_test_labels if stratify else None
-            val_indices, test_indices = train_test_split(
-                val_test_indices, test_size=(test_size / (val_size + test_size)),
-                random_state=random_state, stratify=stratify_val_test
-            )
-            self._test_dataset = Subset(self.full_dataset, test_indices)
-            _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
-        else:
-            val_indices = val_test_indices
-
-        self._train_dataset = Subset(self.full_dataset, train_indices)
-        self._val_dataset = Subset(self.full_dataset, val_indices)
-        self._is_split = True
-
-        _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
-        return self
-
-    def configure_transforms(self, resize_size: int = 256, crop_size: int = 224,
-                             mean: List[float] = [0.485, 0.456, 0.406],
-                             std: List[float] = [0.229, 0.224, 0.225],
-                             extra_train_transforms: Optional[List] = None) -> 'VisionDatasetMaker':
-        """Configures and applies the image transformations (augmentations)."""
-        if not self._is_split:
-            _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
-            raise RuntimeError()
-
-        base_train_transforms = [transforms.RandomResizedCrop(crop_size), transforms.RandomHorizontalFlip()]
-        if extra_train_transforms:
-            base_train_transforms.extend(extra_train_transforms)
-
-        final_transforms = [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
-
-        val_transform = transforms.Compose([transforms.Resize(resize_size), transforms.CenterCrop(crop_size), *final_transforms])
-        train_transform = transforms.Compose([*base_train_transforms, *final_transforms])
-
-        self._train_dataset.dataset.transform = train_transform # type: ignore
-        self._val_dataset.dataset.transform = val_transform # type: ignore
-        if self._test_dataset:
-            self._test_dataset.dataset.transform = val_transform # type: ignore
-
-        self._are_transforms_configured = True
-        _LOGGER.info("Image transforms configured and applied.")
-        return self
-
-    def get_datasets(self) -> Tuple[Dataset, ...]:
-        """Returns the final train, validation, and optional test datasets."""
-        if not self._is_split:
-            _LOGGER.error("Data has not been split. Call .split_data() first.")
-            raise RuntimeError()
-        if not self._are_transforms_configured:
-            _LOGGER.warning("Transforms have not been configured. Using default ToTensor only.")
-
-        if self._test_dataset:
-            return self._train_dataset, self._val_dataset, self._test_dataset
-        return self._train_dataset, self._val_dataset
-
-
 # --- SequenceMaker ---
 class SequenceMaker(_BaseMaker):
     """
```
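For reference, the class removed above is the 13.7.0 vision pipeline; per the file list, this tooling moves to the new ml_tools/ML_vision_datasetmaster.py (whether the class keeps the same name there is not shown in this diff). A minimal usage sketch against the 13.7.0 API, with "./images" as a placeholder ImageFolder-style directory:

```python
from ml_tools.ML_datasetmaster import VisionDatasetMaker  # 13.7.0 location

# Load -> split -> configure transforms -> collect the final datasets.
maker = VisionDatasetMaker.from_folder("./images")
maker.split_data(val_size=0.2, test_size=0.1, stratify=True, random_state=42)
maker.configure_transforms(resize_size=256, crop_size=224)
train_ds, val_ds, test_ds = maker.get_datasets()  # 3-tuple because test_size > 0
```

The remainder of this file's diff, below, removes the companion ResizeAspectFill transform.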
```diff
@@ -804,40 +656,5 @@ class SequenceMaker(_BaseMaker):
         return self._train_dataset, self._test_dataset
 
 
-# --- Custom Vision Transform Class ---
-class ResizeAspectFill:
-    """
-    Custom transformation to make an image square by padding it to match the
-    longest side, preserving the aspect ratio. The image is finally centered.
-
-    Args:
-        pad_color (Union[str, int]): Color to use for the padding.
-            Defaults to "black".
-    """
-    def __init__(self, pad_color: Union[str, int] = "black") -> None:
-        self.pad_color = pad_color
-
-    def __call__(self, image: Image.Image) -> Image.Image:
-        if not isinstance(image, Image.Image):
-            _LOGGER.error(f"Expected PIL.Image.Image, got {type(image).__name__}")
-            raise TypeError()
-
-        w, h = image.size
-        if w == h:
-            return image
-
-        # Determine padding to center the image
-        if w > h:
-            top_padding = (w - h) // 2
-            bottom_padding = w - h - top_padding
-            padding = (0, top_padding, 0, bottom_padding)
-        else: # h > w
-            left_padding = (h - w) // 2
-            right_padding = h - w - left_padding
-            padding = (left_padding, 0, right_padding, 0)
-
-        return ImageOps.expand(image, padding, fill=self.pad_color)
-
-
 def info():
     _script_info(__all__)
```
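The hunk above drops ResizeAspectFill, which pads the shorter side of an image so the output is square without distorting the aspect ratio; it presumably reappears in the new ml_tools/ML_vision_transformers.py. A sketch of how it composed with torchvision transforms under the 13.7.0 layout ("photo.jpg" is a placeholder file):

```python
from PIL import Image
from torchvision import transforms
from ml_tools.ML_datasetmaster import ResizeAspectFill  # 13.7.0 location

# A 300x200 input gains 50px black bars top and bottom (becoming 300x300),
# then is resized to 224x224 with no stretching.
pipeline = transforms.Compose([
    ResizeAspectFill(pad_color="black"),
    transforms.Resize(224),
    transforms.ToTensor(),
])
tensor = pipeline(Image.open("photo.jpg"))
```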
{dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_evaluation.py
RENAMED
```diff
@@ -24,7 +24,7 @@ import warnings
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from .keys import SHAPKeys, PyTorchLogKeys
 
 
 __all__ = [
@@ -44,8 +44,8 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
         save_dir (str | Path): Directory to save the plot image.
     """
-    train_loss = history.get('train_loss', [])
-    val_loss = history.get('val_loss', [])
+    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
 
     if not train_loss and not val_loss:
         print("Warning: Loss history is empty or incomplete. Cannot plot.")
```
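plot_losses now reads the history dict through the shared PyTorchLogKeys constants instead of hard-coded strings, so the trainer and the plotting code cannot drift apart. The string values behind the constants live in keys.py and are not shown in this diff, which is exactly why callers should build the dict through them. A hedged usage sketch:

```python
from ml_tools.keys import PyTorchLogKeys
from ml_tools.ML_evaluation import plot_losses

# Hypothetical history; a real one would come from the training loop.
history = {
    PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.44, 0.37],
    PyTorchLogKeys.VAL_LOSS:   [0.98, 0.70, 0.55, 0.52],
}
plot_losses(history, save_dir="./plots")  # "./plots" is a placeholder directory
```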
{dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_inference.py
RENAMED
```diff
@@ -82,7 +82,6 @@ class _BaseInferenceHandler(ABC):
             _LOGGER.warning("CUDA not available, switching to CPU.")
             device_lower = "cpu"
         elif device_lower == "mps" and not torch.backends.mps.is_available():
-            # Your M-series Mac will appreciate this check!
             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
             device_lower = "cpu"
         return torch.device(device_lower)
```
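The only change here deletes a stray comment, but the surrounding logic is the part worth knowing: a requested accelerator silently degrades to CPU when unavailable. A standalone sketch of the same fallback pattern (the handler's actual method name is not visible in this hunk):

```python
import torch

def resolve_device(requested: str = "cuda") -> torch.device:
    """Fall back to CPU when the requested accelerator is unavailable."""
    device = requested.lower()
    if device == "cuda" and not torch.cuda.is_available():
        device = "cpu"  # no NVIDIA GPU visible to PyTorch
    elif device == "mps" and not torch.backends.mps.is_available():
        device = "cpu"  # no Apple Metal backend available
    return torch.device(device)
```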
{dragon_ml_toolbox-13.7.0 → dragon_ml_toolbox-14.0.0}/ml_tools/ML_models.py
RENAMED
```diff
@@ -1,6 +1,6 @@
 import torch
 from torch import nn
-from typing import List, Union, Tuple, Dict, Any
+from typing import List, Union, Tuple, Dict, Any, Literal, Optional
 from pathlib import Path
 import json
 
@@ -748,5 +748,7 @@ class SequencePredictorLSTM(nn.Module, _ArchitectureHandlerMixin):
     )
 
 
+# ---- PyTorch models ---
+
 def info():
     _script_info(__all__)
```