dragon-ml-toolbox 10.5.0__py3-none-any.whl → 10.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-10.5.0.dist-info → dragon_ml_toolbox-10.7.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-10.5.0.dist-info → dragon_ml_toolbox-10.7.0.dist-info}/RECORD +11 -11
- ml_tools/ML_datasetmaster.py +25 -5
- ml_tools/ML_models.py +63 -81
- ml_tools/ML_scaler.py +1 -1
- ml_tools/ML_trainer.py +3 -7
- ml_tools/keys.py +7 -0
- {dragon_ml_toolbox-10.5.0.dist-info → dragon_ml_toolbox-10.7.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.5.0.dist-info → dragon_ml_toolbox-10.7.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-10.5.0.dist-info → dragon_ml_toolbox-10.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-10.5.0.dist-info → dragon_ml_toolbox-10.7.0.dist-info}/top_level.txt +0 -0
|
@@ -1,18 +1,18 @@
|
|
|
1
|
-
dragon_ml_toolbox-10.
|
|
2
|
-
dragon_ml_toolbox-10.
|
|
1
|
+
dragon_ml_toolbox-10.7.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
|
|
2
|
+
dragon_ml_toolbox-10.7.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
|
|
3
3
|
ml_tools/ETL_cleaning.py,sha256=lSP5q6-ukGhJBPV8dlsqJvPXAzj4du_0J-SbtEd0Pjg,19292
|
|
4
4
|
ml_tools/ETL_engineering.py,sha256=a6KCWH6kRatZtjaFEF_o917ApPMK5_vRD-BjfCDAl-E,49400
|
|
5
5
|
ml_tools/GUI_tools.py,sha256=kEQWg-bog3pB5tI22gMGKWaCGHnz9TB2Lvvfhf5F2CI,45412
|
|
6
6
|
ml_tools/MICE_imputation.py,sha256=kVSythWfxJFR4-2mtcYCWQaQ1Oz5yyx_SJu5gjnS7H8,11670
|
|
7
7
|
ml_tools/ML_callbacks.py,sha256=JPvEw_cW5tYNJ2rMSgnNrKLuni_UrmuhDFaOw-u2SvA,13926
|
|
8
|
-
ml_tools/ML_datasetmaster.py,sha256=
|
|
8
|
+
ml_tools/ML_datasetmaster.py,sha256=uenjHP-Mh4tn20rWSEGN_JsCPvuPNDGW-PElBhb2a4I,30346
|
|
9
9
|
ml_tools/ML_evaluation.py,sha256=28JJ2M71p4pxniwav2Hv3b1a5dsvaoIYNLm-UJQuXvY,16002
|
|
10
10
|
ml_tools/ML_evaluation_multi.py,sha256=2jTSNFCu8cz5C05EusnrDyffs59M2Fq3UXSHxo2TR1A,12515
|
|
11
11
|
ml_tools/ML_inference.py,sha256=SGDPiPxs_OYDKKRZziFMyaWcC8A37c70W9t-dMP5niI,23066
|
|
12
|
-
ml_tools/ML_models.py,sha256=
|
|
12
|
+
ml_tools/ML_models.py,sha256=A_yeULMxT3IAuJuwIF5nXdAQwQDGsxHlbDSxtlzVG44,27699
|
|
13
13
|
ml_tools/ML_optimization.py,sha256=a2Uxe1g-y4I-gFa8ENIM8QDS-Pz3hoPRRaVXAWMbyQA,13491
|
|
14
|
-
ml_tools/ML_scaler.py,sha256=
|
|
15
|
-
ml_tools/ML_trainer.py,sha256=
|
|
14
|
+
ml_tools/ML_scaler.py,sha256=yKVrXW6dWV6UoC9ViLMzORfXQXvGTJvzkNbSrB0F5t0,7447
|
|
15
|
+
ml_tools/ML_trainer.py,sha256=xw1zMgYpdqwsTt604xe3GTQNvpg6z6Ze-avmitGBFeU,23539
|
|
16
16
|
ml_tools/PSO_optimization.py,sha256=q0VYpssQGbPum7xdnkDXlJQKhZMYZo8acHpKhajPK3c,22954
|
|
17
17
|
ml_tools/RNN_forecast.py,sha256=8rNZr-eWOBXMiDQV22e_tQTPM5LM2IFggEAa1FaoXaI,1965
|
|
18
18
|
ml_tools/SQL.py,sha256=WDgdZUYuLBUpv-4Am9XjVY_Aq_jxBWdLrbcgAIEwefI,10704
|
|
@@ -26,11 +26,11 @@ ml_tools/ensemble_evaluation.py,sha256=xMEMfXJ5MjTkTfr1LkFOeD7iUtnVDCW3S9lm3zT-6
|
|
|
26
26
|
ml_tools/ensemble_inference.py,sha256=EFHnbjbu31fcVp88NBx8lWAVdu2Gpg9MY9huVZJHFfM,9350
|
|
27
27
|
ml_tools/ensemble_learning.py,sha256=3s0kH4i_naj0IVl_T4knst-Hwg4TScWjEdsXX5KAi7I,21929
|
|
28
28
|
ml_tools/handle_excel.py,sha256=He4UT15sCGhaG-JKfs7uYVAubxWjrqgJ6U7OhMR2fuE,14005
|
|
29
|
-
ml_tools/keys.py,sha256=
|
|
29
|
+
ml_tools/keys.py,sha256=ThuyNnSV4iK712WRaGXEm9uGW8Dg3djKa7HFRmPCRr4,1228
|
|
30
30
|
ml_tools/optimization_tools.py,sha256=P3I6lIpvZ8Xf2kX5FvvBKBmrK2pB6idBpkTzfUJxTeE,5073
|
|
31
31
|
ml_tools/path_manager.py,sha256=7sRvAoNrboRY6ef9gH3_qdzoZ66iLs7Aii4P39K0kEk,13819
|
|
32
32
|
ml_tools/utilities.py,sha256=SVMaSDigh6SUoAeig2_sXLLIj5w5mUs5KuVWpHvFDec,19816
|
|
33
|
-
dragon_ml_toolbox-10.
|
|
34
|
-
dragon_ml_toolbox-10.
|
|
35
|
-
dragon_ml_toolbox-10.
|
|
36
|
-
dragon_ml_toolbox-10.
|
|
33
|
+
dragon_ml_toolbox-10.7.0.dist-info/METADATA,sha256=3yKY50Qa3kt1lvDo_ELk3dUczIunDGuf6bB3UaiBl9g,6968
|
|
34
|
+
dragon_ml_toolbox-10.7.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
35
|
+
dragon_ml_toolbox-10.7.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
|
|
36
|
+
dragon_ml_toolbox-10.7.0.dist-info/RECORD,,
|
ml_tools/ML_datasetmaster.py
CHANGED
|
@@ -34,7 +34,9 @@ class _PytorchDataset(Dataset):
|
|
|
34
34
|
def __init__(self, features: Union[numpy.ndarray, pandas.DataFrame],
|
|
35
35
|
labels: Union[numpy.ndarray, pandas.Series],
|
|
36
36
|
labels_dtype: torch.dtype,
|
|
37
|
-
features_dtype: torch.dtype = torch.float32
|
|
37
|
+
features_dtype: torch.dtype = torch.float32,
|
|
38
|
+
feature_names: Optional[List[str]] = None,
|
|
39
|
+
target_names: Optional[List[str]] = None):
|
|
38
40
|
"""
|
|
39
41
|
integer labels for classification.
|
|
40
42
|
|
|
@@ -50,12 +52,30 @@ class _PytorchDataset(Dataset):
|
|
|
50
52
|
self.labels = torch.tensor(labels, dtype=labels_dtype)
|
|
51
53
|
else:
|
|
52
54
|
self.labels = torch.tensor(labels.values, dtype=labels_dtype)
|
|
55
|
+
|
|
56
|
+
self._feature_names = feature_names
|
|
57
|
+
self._target_names = target_names
|
|
53
58
|
|
|
54
59
|
def __len__(self):
|
|
55
60
|
return len(self.features)
|
|
56
61
|
|
|
57
62
|
def __getitem__(self, index):
|
|
58
63
|
return self.features[index], self.labels[index]
|
|
64
|
+
|
|
65
|
+
@property
|
|
66
|
+
def feature_names(self):
|
|
67
|
+
if self._feature_names is not None:
|
|
68
|
+
return self._feature_names
|
|
69
|
+
else:
|
|
70
|
+
_LOGGER.error(f"Dataset {self.__class__} has not been initialized with any feature names.")
|
|
71
|
+
raise ValueError()
|
|
72
|
+
|
|
73
|
+
@property
|
|
74
|
+
def target_names(self):
|
|
75
|
+
if self._target_names is not None:
|
|
76
|
+
return self._target_names
|
|
77
|
+
else:
|
|
78
|
+
_LOGGER.error(f"Dataset {self.__class__} has not been initialized with any target names.")
|
|
59
79
|
|
|
60
80
|
|
|
61
81
|
# --- Abstract Base Class (New) ---
|
|
@@ -229,8 +249,8 @@ class DatasetMaker(_BaseDatasetMaker):
|
|
|
229
249
|
)
|
|
230
250
|
|
|
231
251
|
# --- 4. Create Datasets ---
|
|
232
|
-
self._train_ds = _PytorchDataset(X_train_final, y_train.values, label_dtype)
|
|
233
|
-
self._test_ds = _PytorchDataset(X_test_final, y_test.values, label_dtype)
|
|
252
|
+
self._train_ds = _PytorchDataset(X_train_final, y_train.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=[self._target_name])
|
|
253
|
+
self._test_ds = _PytorchDataset(X_test_final, y_test.values, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=[self._target_name])
|
|
234
254
|
|
|
235
255
|
@property
|
|
236
256
|
def target_name(self) -> str:
|
|
@@ -280,8 +300,8 @@ class DatasetMakerMulti(_BaseDatasetMaker):
|
|
|
280
300
|
X_train, y_train, X_test, label_dtype, continuous_feature_columns
|
|
281
301
|
)
|
|
282
302
|
|
|
283
|
-
self._train_ds = _PytorchDataset(X_train_final, y_train, label_dtype)
|
|
284
|
-
self._test_ds = _PytorchDataset(X_test_final, y_test, label_dtype)
|
|
303
|
+
self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
|
|
304
|
+
self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
|
|
285
305
|
|
|
286
306
|
@property
|
|
287
307
|
def target_names(self) -> list[str]:
|
ml_tools/ML_models.py
CHANGED
|
@@ -6,6 +6,8 @@ import json
|
|
|
6
6
|
from ._logger import _LOGGER
|
|
7
7
|
from .path_manager import make_fullpath
|
|
8
8
|
from ._script_info import _script_info
|
|
9
|
+
from .keys import PytorchModelKeys
|
|
10
|
+
|
|
9
11
|
|
|
10
12
|
__all__ = [
|
|
11
13
|
"MultilayerPerceptron",
|
|
@@ -13,12 +15,63 @@ __all__ = [
|
|
|
13
15
|
"MultiHeadAttentionMLP",
|
|
14
16
|
"TabularTransformer",
|
|
15
17
|
"SequencePredictorLSTM",
|
|
16
|
-
"save_architecture",
|
|
17
|
-
"load_architecture"
|
|
18
18
|
]
|
|
19
19
|
|
|
20
20
|
|
|
21
|
-
class
|
|
21
|
+
class _ArchitectureHandlerMixin:
|
|
22
|
+
"""
|
|
23
|
+
A mixin class to provide save and load functionality for model architectures.
|
|
24
|
+
"""
|
|
25
|
+
def save(self: nn.Module, directory: Union[str, Path], verbose: bool = True): # type: ignore
|
|
26
|
+
"""Saves the model's architecture to a JSON file."""
|
|
27
|
+
if not hasattr(self, 'get_architecture_config'):
|
|
28
|
+
_LOGGER.error(f"Model '{self.__class__.__name__}' must have a 'get_architecture_config()' method to use this functionality.")
|
|
29
|
+
raise AttributeError()
|
|
30
|
+
|
|
31
|
+
path_dir = make_fullpath(directory, make=True, enforce="directory")
|
|
32
|
+
full_path = path_dir / PytorchModelKeys.SAVENAME
|
|
33
|
+
|
|
34
|
+
config = {
|
|
35
|
+
PytorchModelKeys.MODEL: self.__class__.__name__,
|
|
36
|
+
PytorchModelKeys.CONFIG: self.get_architecture_config() # type: ignore
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
with open(full_path, 'w') as f:
|
|
40
|
+
json.dump(config, f, indent=4)
|
|
41
|
+
|
|
42
|
+
if verbose:
|
|
43
|
+
_LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved to '{path_dir.name}'")
|
|
44
|
+
|
|
45
|
+
@classmethod
|
|
46
|
+
def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
|
|
47
|
+
"""Loads a model architecture from a JSON file. If a directory is provided, the function will attempt to load a JSON file inside."""
|
|
48
|
+
user_path = make_fullpath(file_or_dir)
|
|
49
|
+
|
|
50
|
+
if user_path.is_dir():
|
|
51
|
+
target_path = make_fullpath(user_path / PytorchModelKeys.SAVENAME, enforce="file")
|
|
52
|
+
elif user_path.is_file():
|
|
53
|
+
target_path = user_path
|
|
54
|
+
else:
|
|
55
|
+
_LOGGER.error(f"Invalid path: '{file_or_dir}'")
|
|
56
|
+
raise IOError()
|
|
57
|
+
|
|
58
|
+
with open(target_path, 'r') as f:
|
|
59
|
+
saved_data = json.load(f)
|
|
60
|
+
|
|
61
|
+
saved_class_name = saved_data[PytorchModelKeys.MODEL]
|
|
62
|
+
config = saved_data[PytorchModelKeys.CONFIG]
|
|
63
|
+
|
|
64
|
+
if saved_class_name != cls.__name__:
|
|
65
|
+
_LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
|
|
66
|
+
raise ValueError()
|
|
67
|
+
|
|
68
|
+
model = cls(**config)
|
|
69
|
+
if verbose:
|
|
70
|
+
_LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
|
|
71
|
+
return model
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class _BaseMLP(nn.Module, _ArchitectureHandlerMixin):
|
|
22
75
|
"""
|
|
23
76
|
A base class for Multilayer Perceptrons.
|
|
24
77
|
|
|
@@ -68,7 +121,7 @@ class _BaseMLP(nn.Module):
|
|
|
68
121
|
# Set a customizable Prediction Head for flexibility, specially in transfer learning and fine-tuning
|
|
69
122
|
self.output_layer = nn.Linear(current_features, out_targets)
|
|
70
123
|
|
|
71
|
-
def
|
|
124
|
+
def get_architecture_config(self) -> Dict[str, Any]:
|
|
72
125
|
"""Returns the base configuration of the model."""
|
|
73
126
|
return {
|
|
74
127
|
'in_features': self.in_features,
|
|
@@ -228,9 +281,9 @@ class MultiHeadAttentionMLP(_BaseMLP):
|
|
|
228
281
|
|
|
229
282
|
return logits, attention_weights
|
|
230
283
|
|
|
231
|
-
def
|
|
284
|
+
def get_architecture_config(self) -> Dict[str, Any]:
|
|
232
285
|
"""Returns the full configuration of the model."""
|
|
233
|
-
config = super().
|
|
286
|
+
config = super().get_architecture_config()
|
|
234
287
|
config['num_heads'] = self.num_heads
|
|
235
288
|
config['attention_dropout'] = self.attention_dropout
|
|
236
289
|
return config
|
|
@@ -247,7 +300,7 @@ class MultiHeadAttentionMLP(_BaseMLP):
|
|
|
247
300
|
return f"MultiHeadAttentionMLP(arch: {arch_str})"
|
|
248
301
|
|
|
249
302
|
|
|
250
|
-
class TabularTransformer(nn.Module):
|
|
303
|
+
class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
|
|
251
304
|
"""
|
|
252
305
|
A Transformer-based model for tabular data tasks.
|
|
253
306
|
|
|
@@ -357,7 +410,7 @@ class TabularTransformer(nn.Module):
|
|
|
357
410
|
|
|
358
411
|
return logits
|
|
359
412
|
|
|
360
|
-
def
|
|
413
|
+
def get_architecture_config(self) -> Dict[str, Any]:
|
|
361
414
|
"""Returns the full configuration of the model."""
|
|
362
415
|
return {
|
|
363
416
|
'out_targets': self.out_targets,
|
|
@@ -529,7 +582,7 @@ class _MultiHeadAttentionLayer(nn.Module):
|
|
|
529
582
|
return out, attn_weights.squeeze()
|
|
530
583
|
|
|
531
584
|
|
|
532
|
-
class SequencePredictorLSTM(nn.Module):
|
|
585
|
+
class SequencePredictorLSTM(nn.Module, _ArchitectureHandlerMixin):
|
|
533
586
|
"""
|
|
534
587
|
A simple LSTM-based network for sequence-to-sequence prediction tasks.
|
|
535
588
|
|
|
@@ -597,7 +650,7 @@ class SequencePredictorLSTM(nn.Module):
|
|
|
597
650
|
|
|
598
651
|
return predictions
|
|
599
652
|
|
|
600
|
-
def
|
|
653
|
+
def get_architecture_config(self) -> dict:
|
|
601
654
|
"""Returns the configuration of the model."""
|
|
602
655
|
return {
|
|
603
656
|
'features': self.features,
|
|
@@ -615,76 +668,5 @@ class SequencePredictorLSTM(nn.Module):
|
|
|
615
668
|
)
|
|
616
669
|
|
|
617
670
|
|
|
618
|
-
def save_architecture(model: nn.Module, directory: Union[str, Path], verbose: bool=True):
|
|
619
|
-
"""
|
|
620
|
-
Saves a model's architecture to a 'architecture.json' file.
|
|
621
|
-
|
|
622
|
-
This function relies on the model having a `get_config()` method that
|
|
623
|
-
returns a dictionary of the arguments needed to initialize it.
|
|
624
|
-
|
|
625
|
-
Args:
|
|
626
|
-
model (nn.Module): The PyTorch model instance to save.
|
|
627
|
-
directory (str | Path): The directory to save the JSON file.
|
|
628
|
-
|
|
629
|
-
Raises:
|
|
630
|
-
AttributeError: If the model does not have a `get_config()` method.
|
|
631
|
-
"""
|
|
632
|
-
if not hasattr(model, 'get_config'):
|
|
633
|
-
_LOGGER.error(f"Model '{model.__class__.__name__}' does not have a 'get_config()' method.")
|
|
634
|
-
raise AttributeError()
|
|
635
|
-
|
|
636
|
-
# Ensure the target directory exists
|
|
637
|
-
path_dir = make_fullpath(directory, make=True, enforce="directory")
|
|
638
|
-
full_path = path_dir / "architecture.json"
|
|
639
|
-
|
|
640
|
-
config = {
|
|
641
|
-
'model_class': model.__class__.__name__,
|
|
642
|
-
'config': model.get_config() # type: ignore
|
|
643
|
-
}
|
|
644
|
-
|
|
645
|
-
with open(full_path, 'w') as f:
|
|
646
|
-
json.dump(config, f, indent=4)
|
|
647
|
-
|
|
648
|
-
if verbose:
|
|
649
|
-
_LOGGER.info(f"Architecture for '{model.__class__.__name__}' saved to '{path_dir.name}'")
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
def load_architecture(filepath: Union[str, Path], expected_model_class: type, verbose: bool=True) -> nn.Module:
|
|
653
|
-
"""
|
|
654
|
-
Loads a model architecture from a JSON file.
|
|
655
|
-
|
|
656
|
-
This function instantiates a model by providing an explicit class to use
|
|
657
|
-
and checking that it matches the class name specified in the file.
|
|
658
|
-
|
|
659
|
-
Args:
|
|
660
|
-
filepath (Union[str, Path]): The path of the JSON architecture file.
|
|
661
|
-
expected_model_class (type): The model class expected to load (e.g., MultilayerPerceptron).
|
|
662
|
-
|
|
663
|
-
Returns:
|
|
664
|
-
nn.Module: An instance of the model with a freshly initialized state.
|
|
665
|
-
|
|
666
|
-
Raises:
|
|
667
|
-
FileNotFoundError: If the filepath does not exist.
|
|
668
|
-
ValueError: If the class name in the file does not match the `expected_model_class`.
|
|
669
|
-
"""
|
|
670
|
-
path_obj = make_fullpath(filepath, enforce="file")
|
|
671
|
-
|
|
672
|
-
with open(path_obj, 'r') as f:
|
|
673
|
-
saved_data = json.load(f)
|
|
674
|
-
|
|
675
|
-
saved_class_name = saved_data['model_class']
|
|
676
|
-
config = saved_data['config']
|
|
677
|
-
|
|
678
|
-
if saved_class_name != expected_model_class.__name__:
|
|
679
|
-
_LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{expected_model_class.__name__}' was expected.")
|
|
680
|
-
raise ValueError()
|
|
681
|
-
|
|
682
|
-
# Create an instance of the model using the provided class and config
|
|
683
|
-
model = expected_model_class(**config)
|
|
684
|
-
if verbose:
|
|
685
|
-
_LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
|
|
686
|
-
return model
|
|
687
|
-
|
|
688
|
-
|
|
689
671
|
def info():
|
|
690
672
|
_script_info(__all__)
|
ml_tools/ML_scaler.py
CHANGED
|
@@ -156,7 +156,7 @@ class PytorchScaler:
|
|
|
156
156
|
Args:
|
|
157
157
|
filepath (str | Path): The path to save the file.
|
|
158
158
|
"""
|
|
159
|
-
path_obj = make_fullpath(filepath)
|
|
159
|
+
path_obj = make_fullpath(filepath, make=True, enforce="file")
|
|
160
160
|
state = {
|
|
161
161
|
'mean': self.mean_,
|
|
162
162
|
'std': self.std_,
|
ml_tools/ML_trainer.py
CHANGED
|
@@ -357,7 +357,7 @@ class MLTrainer:
|
|
|
357
357
|
If None, the trainer's test dataset is used.
|
|
358
358
|
n_samples (int): The number of samples to use for both background and explanation.
|
|
359
359
|
feature_names (list[str] | None): Feature names.
|
|
360
|
-
target_names (list[str] | None): Target names
|
|
360
|
+
target_names (list[str] | None): Target names for multi-target tasks.
|
|
361
361
|
save_dir (str | Path): Directory to save all SHAP artifacts.
|
|
362
362
|
"""
|
|
363
363
|
# Internal helper to create a dataloader and get a random sample
|
|
@@ -408,12 +408,8 @@ class MLTrainer:
|
|
|
408
408
|
if hasattr(target_dataset, "feature_names"):
|
|
409
409
|
feature_names = target_dataset.feature_names # type: ignore
|
|
410
410
|
else:
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
feature_names = target_dataset.dataset.feature_names # type: ignore
|
|
414
|
-
except AttributeError:
|
|
415
|
-
_LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
|
|
416
|
-
raise ValueError()
|
|
411
|
+
_LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
|
|
412
|
+
raise ValueError()
|
|
417
413
|
|
|
418
414
|
# 3. Call the plotting function
|
|
419
415
|
if self.kind in ["regression", "classification"]:
|
ml_tools/keys.py
CHANGED
|
@@ -38,6 +38,13 @@ class PyTorchInferenceKeys:
|
|
|
38
38
|
PROBABILITIES = "probabilities"
|
|
39
39
|
|
|
40
40
|
|
|
41
|
+
class PytorchModelKeys:
|
|
42
|
+
"""Keys for saving and loading models"""
|
|
43
|
+
MODEL = 'model_class'
|
|
44
|
+
CONFIG = "config"
|
|
45
|
+
SAVENAME = "architecture.json"
|
|
46
|
+
|
|
47
|
+
|
|
41
48
|
class _OneHotOtherPlaceholder:
|
|
42
49
|
"""Used internally by GUI_tools."""
|
|
43
50
|
OTHER_GUI = "OTHER"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|