dragon-ml-toolbox 12.13.0__py3-none-any.whl → 14.3.0__py3-none-any.whl
This diff shows the content of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/METADATA +11 -2
- dragon_ml_toolbox-14.3.0.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +40 -8
- ml_tools/ML_datasetmaster.py +200 -261
- ml_tools/ML_evaluation.py +29 -17
- ml_tools/ML_evaluation_multi.py +13 -10
- ml_tools/ML_inference.py +14 -5
- ml_tools/ML_models.py +135 -55
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +49 -36
- ml_tools/ML_trainer.py +560 -30
- ml_tools/ML_utilities.py +302 -4
- ml_tools/ML_vision_datasetmaster.py +1352 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/PSO_optimization.py +5 -1
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_schema.py +96 -0
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +576 -138
- ml_tools/keys.py +51 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +65 -86
- ml_tools/serde.py +78 -17
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-12.13.0.dist-info/RECORD +0 -41
- ml_tools/ML_simple_optimization.py +0 -413
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
```diff
@@ -19,11 +19,12 @@ import torch
 import shap
 from pathlib import Path
 from typing import Union, Optional, List, Literal
+import warnings
 
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from .keys import SHAPKeys, PyTorchLogKeys
 
 
 __all__ = [
@@ -43,8 +44,8 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
         save_dir (str | Path): Directory to save the plot image.
     """
-    train_loss = history.get(
-    val_loss = history.get(
+    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
 
     if not train_loss and not val_loss:
         print("Warning: Loss history is empty or incomplete. Cannot plot.")
```
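The `plot_losses` fix replaces hard-coded history keys with the shared `PyTorchLogKeys` constants, so the trainer that writes the history and the code that plots it can no longer drift apart. A minimal sketch of the lookup-and-plot logic, assuming the constants resolve to the `'train_loss'`/`'val_loss'` strings the docstring mentions (the real values live in `ml_tools/keys.py` and are not visible in this diff):

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # headless backend, so the script only writes files
import matplotlib.pyplot as plt


class PyTorchLogKeys:
    """Hypothetical stand-in for ml_tools.keys.PyTorchLogKeys."""
    TRAIN_LOSS = "train_loss"
    VAL_LOSS = "val_loss"


def plot_losses_sketch(history: dict, save_dir: str) -> None:
    # Same guarded lookups as the patched plot_losses()
    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
    if not train_loss and not val_loss:
        print("Warning: Loss history is empty or incomplete. Cannot plot.")
        return

    fig, ax = plt.subplots()
    if train_loss:
        ax.plot(range(1, len(train_loss) + 1), train_loss, label="train")
    if val_loss:
        ax.plot(range(1, len(val_loss) + 1), val_loss, label="validation")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

    out = Path(save_dir)
    out.mkdir(parents=True, exist_ok=True)
    fig.savefig(out / "loss_plot.svg")
    plt.close(fig)


plot_losses_sketch({"train_loss": [1.0, 0.6, 0.4], "val_loss": [1.1, 0.7, 0.5]}, "plots")
```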
```diff
@@ -257,7 +258,7 @@ def shap_summary_plot(model,
                       feature_names: Optional[list[str]],
                       save_dir: Union[str, Path],
                       device: torch.device = torch.device('cpu'),
-                      explainer_type: Literal['deep', 'kernel'] = '
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
     """
     Calculates SHAP values and saves summary plots and data.
 
@@ -269,7 +270,7 @@ def shap_summary_plot(model,
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
               PyTorch models.
             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
               slow and memory-intensive.
@@ -284,7 +285,7 @@ def shap_summary_plot(model,
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -298,17 +299,19 @@ def shap_summary_plot(model,
 
         background_data = background_data.to(device)
         instances_to_explain = instances_to_explain.to(device)
-
-
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
         # print("Calculating SHAP values with DeepExplainer...")
         shap_values = explainer.shap_values(instances_to_explain)
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
         # Ensure data is np.ndarray
```
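The deep branch now builds `shap.DeepExplainer` inside a `warnings.catch_warnings()` block, which silences the `UserWarning` chatter some PyTorch layers trigger without disabling warnings globally: the filter applies only inside the `with` block and the previous filters are restored on exit. A stdlib-only sketch of that scoping, with a stand-in for the noisy SHAP constructor:

```python
import warnings


def noisy_third_party_call() -> str:
    # Stand-in for shap.DeepExplainer(...), which can emit UserWarnings
    # about unrecognized layers on some PyTorch models.
    warnings.warn("Some layer was not recognized.", UserWarning)
    return "explainer"


# Scoped suppression: only UserWarnings raised inside the block are ignored.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=UserWarning)
    explainer = noisy_third_party_call()

# Outside the block the saved filters are restored, so this one surfaces.
warnings.warn("This warning is visible again.", UserWarning)
```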
```diff
@@ -344,14 +347,26 @@ def shap_summary_plot(model,
     else:
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1)
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
+    # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
     # Save Bar Plot
     bar_path = save_dir_path / "shap_bar_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     plt.title("SHAP Feature Importance")
@@ -362,7 +377,7 @@ def shap_summary_plot(model,
 
     # Save Dot Plot
     dot_path = save_dir_path / "shap_dot_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -385,9 +400,6 @@ def shap_summary_plot(model,
     mean_abs_shap = np.abs(shap_values).mean(axis=0)
 
     mean_abs_shap = mean_abs_shap.flatten()
-
-    if feature_names is None:
-        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
     summary_df = pd.DataFrame({
         SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -397,7 +409,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)
 
     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()
 
 
 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
```
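Two additions above are easy to miss. Single-output explainers can return SHAP arrays shaped `(N, F, 1)`, which the new guard squeezes down to the `(N, F)` shape the summary plots expect; and the explained instances are now wrapped in a `pandas.DataFrame`, which pins each column to its feature name instead of relying on positional alignment. A small sketch of both steps on synthetic data:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# A single-output, DeepExplainer-style result: (N instances, F features, 1 output).
shap_values = rng.normal(size=(5, 3, 1))
if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
    shap_values = shap_values.squeeze(-1)  # -> (5, 3), what summary_plot expects

# Wrapping instances in a DataFrame binds values to feature names by column.
instances = rng.normal(size=(5, 3))
feature_names = None  # pretend the caller passed no names
if feature_names is None:
    feature_names = [f"feature_{i}" for i in range(instances.shape[1])]
instances_df = pd.DataFrame(instances, columns=feature_names)

print(shap_values.shape)              # (5, 3)
print(instances_df.columns.tolist())  # ['feature_0', 'feature_1', 'feature_2']
```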
ml_tools/ML_evaluation_multi.py
CHANGED
```diff
@@ -20,6 +20,7 @@ from sklearn.metrics import (
 )
 from pathlib import Path
 from typing import Union, List, Literal
+import warnings
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
@@ -234,7 +235,7 @@ def multi_target_shap_summary_plot(
     target_names: List[str],
     save_dir: Union[str, Path],
     device: torch.device = torch.device('cpu'),
-    explainer_type: Literal['deep', 'kernel'] = '
+    explainer_type: Literal['deep', 'kernel'] = 'kernel'
 ):
     """
     Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -248,7 +249,7 @@ def multi_target_shap_summary_plot(
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
     _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -259,7 +260,7 @@ def multi_target_shap_summary_plot(
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -273,18 +274,20 @@ def multi_target_shap_summary_plot(
 
         background_data = background_data.to(device)
         instances_to_explain = instances_to_explain.to(device)
-
-
-
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
+        # print("Calculating SHAP values with DeepExplainer...")
         # DeepExplainer returns a list of arrays for multi-output models
         shap_values_list = explainer.shap_values(instances_to_explain)
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
         # Convert all data to numpy
@@ -304,7 +307,7 @@ def multi_target_shap_summary_plot(
             return output.cpu().numpy()  # Return full multi-output array
 
         explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
-        print("Calculating SHAP values with KernelExplainer...")
+        # print("Calculating SHAP values with KernelExplainer...")
        # KernelExplainer also returns a list of arrays for multi-output models
        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
        # instances_to_explain_np is already set
```
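The kernel branch hinges on `prediction_wrapper`, the NumPy-in/NumPy-out adapter that lets the model-agnostic `KernelExplainer` call a PyTorch network like a plain function. A runnable sketch of that adapter around a toy two-target model; with `shap` installed it would plug into `shap.KernelExplainer` exactly as the diff shows:

```python
import numpy as np
import torch
import torch.nn as nn

# Toy two-target regression network standing in for the library's model.
model = nn.Linear(4, 2)
model.eval()
device = torch.device("cpu")


def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
    """Adapter: KernelExplainer only ever sees NumPy arrays."""
    with torch.no_grad():
        x = torch.from_numpy(x_np.astype(np.float32)).to(device)
        output = model(x)
    return output.cpu().numpy()  # full (N, n_targets) array for multi-output


background = np.zeros((10, 4), dtype=np.float32)
print(prediction_wrapper(background).shape)  # (10, 2)

# With shap installed (per the diff):
#   explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
#   shap_values_list = explainer.shap_values(instances_np, l1_reg="aic")
```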
ml_tools/ML_inference.py
CHANGED
```diff
@@ -9,7 +9,7 @@ from .ML_scaler import PytorchScaler
 from ._script_info import _script_info
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
-from .keys import PyTorchInferenceKeys
+from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
 
 
 __all__ = [
@@ -56,11 +56,21 @@ class _BaseInferenceHandler(ABC):
         model_p = make_fullpath(state_dict, enforce="file")
 
         try:
-            # Load
-
+            # Load whatever is in the file
+            loaded_data = torch.load(model_p, map_location=self.device)
+
+            # Check if it's the new checkpoint dictionary or an old weights-only file
+            if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
+                # It's a new training checkpoint, extract the weights
+                self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
+            else:
+                # It's an old-style file (or just a state_dict), load it directly
+                self.model.load_state_dict(loaded_data)
+
+            _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
+
             self.model.to(self.device)
             self.model.eval() # Set the model to evaluation mode
-            _LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
         except Exception as e:
             _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
             raise
@@ -72,7 +82,6 @@ class _BaseInferenceHandler(ABC):
             _LOGGER.warning("CUDA not available, switching to CPU.")
             device_lower = "cpu"
         elif device_lower == "mps" and not torch.backends.mps.is_available():
-            # Your M-series Mac will appreciate this check!
             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
             device_lower = "cpu"
         return torch.device(device_lower)
```
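The loader is now format-agnostic: `torch.load` may return either a full training checkpoint dictionary, with the weights stored under `PyTorchCheckpointKeys.MODEL_STATE`, or a bare `state_dict` written by older releases. A sketch of the same branching; the key's string value is not visible in this diff, so the constant below is a hypothetical stand-in:

```python
import torch
import torch.nn as nn


class PyTorchCheckpointKeys:
    """Hypothetical stand-in for ml_tools.keys.PyTorchCheckpointKeys."""
    MODEL_STATE = "model_state"


def load_weights_compat(model: nn.Module, path: str, device: str = "cpu") -> nn.Module:
    """Load either a full training checkpoint or a bare state_dict."""
    loaded_data = torch.load(path, map_location=device)
    if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
        # New-style checkpoint: weights sit next to optimizer state, epoch, etc.
        model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
    else:
        # Old-style file: the loaded object *is* the state_dict.
        model.load_state_dict(loaded_data)
    model.to(device)
    model.eval()
    return model


# Round-trip demo covering both formats.
net = nn.Linear(3, 1)
torch.save(net.state_dict(), "old_style.pt")
torch.save({PyTorchCheckpointKeys.MODEL_STATE: net.state_dict(), "epoch": 7}, "checkpoint.pt")
load_weights_compat(nn.Linear(3, 1), "old_style.pt")
load_weights_compat(nn.Linear(3, 1), "checkpoint.pt")
```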
ml_tools/ML_models.py
CHANGED
```diff
@@ -8,6 +8,7 @@ from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from ._script_info import _script_info
 from .keys import PytorchModelArchitectureKeys
+from ._schema import FeatureSchema
 
 
 __all__ = [
```
```diff
@@ -298,76 +299,73 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
     """
     A Transformer-based model for tabular data tasks.
 
-    This model uses a Feature Tokenizer to convert all input features into a
+    This model uses a Feature Tokenizer to convert all input features into a
+    sequence of embeddings, prepends a [CLS] token, and processes the
     sequence with a standard Transformer Encoder.
     """
     def __init__(self, *,
-
+                 schema: FeatureSchema,
                  out_targets: int,
-
-                 embedding_dim: int = 32,
+                 embedding_dim: int = 256,
                  num_heads: int = 8,
                  num_layers: int = 6,
-                 dropout: float = 0.
+                 dropout: float = 0.2):
         """
         Args:
-
-
-
-
-
-
-
-
-
-
-
+            schema (FeatureSchema):
+                The definitive schema object created by `data_exploration.finalize_feature_schema()`.
+            out_targets (int):
+                Number of output targets (1 for regression).
+            embedding_dim (int):
+                The dimension for all feature embeddings. Must be divisible by num_heads. Common values: (64, 128, 192, 256, etc.)
+            num_heads (int):
+                The number of heads in the multi-head attention mechanism. Common values: (4, 8, 16)
+            num_layers (int):
+                The number of sub-encoder-layers in the transformer encoder. Common values: (4, 8, 12)
+            dropout (float):
+                The dropout value.
+
+        ## Note:
 
-
-
-
-
-        **
-
-        their cardinality (the number of unique categories) via the `categorical_map` parameter.
-
-        **Ordinal & Binary Features** (e.g., 'Low/Medium/High', 'True/False'): Should be treated as **numerical**. Map them to numbers that
-        represent their state (e.g., `{'Low': 0, 'Medium': 1}` or `{False: 0, True: 1}`). Their column indices should **NOT** be included in the
-        `categorical_map` parameter.
+        **Embedding Dimension:** "Width" of the model. It's the N-dimension vector that will be used to represent each one of the features.
+        - Each continuous feature gets its own learnable N-dimension vector.
+        - Each categorical feature gets an embedding table that maps every category (e.g., "color=red", "color=blue") to a unique N-dimension vector.
+
+        **Attention Heads:** Controls the "Multi-Head Attention" mechanism. Instead of looking at all the feature interactions at once, the model splits its attention into N parallel heads.
+        - Embedding Dimensions get divided by the number of Attention Heads, resulting in the dimensions assigned per head.
 
-        **
+        **Number of Layers:** "Depth" of the model. Number of identical `TransformerEncoderLayer` blocks that are stacked on top of each other.
+        - Layer 1: The attention heads find simple, direct interactions between the features.
+        - Layer 2: Takes the output of Layer 1 and finds interactions between those interactions and so on.
+        - Trade-off: More layers are more powerful but are slower to train and more prone to overfitting. If the training loss goes down but the validation loss goes up, you might have too many layers (or need more dropout).
+
         """
         super().__init__()
 
+        # --- Get info from schema ---
+        in_features = len(schema.feature_names)
+        categorical_index_map = schema.categorical_index_map
+
         # --- Validation ---
-        if categorical_index_map and max(categorical_index_map.keys()) >= in_features:
+        if categorical_index_map and (max(categorical_index_map.keys()) >= in_features):
             _LOGGER.error(f"A categorical index ({max(categorical_index_map.keys())}) is out of bounds for the provided input features ({in_features}).")
             raise ValueError()
 
-        # --- Derive numerical indices ---
-        all_indices = set(range(in_features))
-        categorical_indices_set = set(categorical_index_map.keys())
-        numerical_indices = sorted(list(all_indices - categorical_indices_set))
-
         # --- Save configuration ---
-        self.
+        self.schema = schema # <-- Save the whole schema
         self.out_targets = out_targets
-        self.numerical_indices = numerical_indices
-        self.categorical_map = categorical_index_map
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.num_layers = num_layers
         self.dropout = dropout
 
-        # --- 1. Feature Tokenizer ---
+        # --- 1. Feature Tokenizer (now takes the schema) ---
         self.tokenizer = _FeatureTokenizer(
-
-            categorical_map=categorical_index_map,
+            schema=schema,
             embedding_dim=embedding_dim
         )
 
         # --- 2. CLS Token ---
-        # A learnable token that will be prepended to the sequence.
         self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
 
         # --- 3. Transformer Encoder ---
```
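The constructor now takes a single `FeatureSchema` instead of separate `in_features` and `categorical_map` arguments, and derives both from it. A sketch using a minimal mirror of the schema, limited to the five fields this diff reads and writes (the real class lives in `ml_tools/_schema.py` and may carry more):

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple


@dataclass(frozen=True)
class FeatureSchema:
    """Minimal mirror of ml_tools._schema.FeatureSchema for illustration."""
    feature_names: Tuple[str, ...]
    continuous_feature_names: Tuple[str, ...]
    categorical_feature_names: Tuple[str, ...]
    categorical_index_map: Optional[Dict[int, int]]  # column index -> cardinality
    categorical_mappings: Optional[Dict[str, Dict[str, int]]]


schema = FeatureSchema(
    feature_names=("age", "color", "weight"),
    continuous_feature_names=("age", "weight"),
    categorical_feature_names=("color",),
    categorical_index_map={1: 3},  # 'color' sits at column 1 and has 3 categories
    categorical_mappings={"color": {"red": 0, "green": 1, "blue": 2}},
)

# What the patched __init__ derives from the schema:
in_features = len(schema.feature_names)                      # 3
cat_indices = set((schema.categorical_index_map or {}).keys())
numerical_indices = sorted(set(range(in_features)) - cat_indices)
print(in_features, numerical_indices)                        # 3 [0, 2]
```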
```diff
@@ -416,21 +414,87 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
 
     def get_architecture_config(self) -> Dict[str, Any]:
         """Returns the full configuration of the model."""
+        # Deconstruct schema into a JSON-friendly dict
+        # Tuples are saved as lists
+        schema_dict = {
+            'feature_names': self.schema.feature_names,
+            'continuous_feature_names': self.schema.continuous_feature_names,
+            'categorical_feature_names': self.schema.categorical_feature_names,
+            'categorical_index_map': self.schema.categorical_index_map,
+            'categorical_mappings': self.schema.categorical_mappings
+        }
+
         return {
-            '
+            'schema_dict': schema_dict,
             'out_targets': self.out_targets,
-            'categorical_map': self.categorical_map,
             'embedding_dim': self.embedding_dim,
             'num_heads': self.num_heads,
             'num_layers': self.num_layers,
             'dropout': self.dropout
         }
+
+    @classmethod
+    def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
+        """Loads a model architecture from a JSON file."""
+        user_path = make_fullpath(file_or_dir)
+
+        if user_path.is_dir():
+            json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+            target_path = make_fullpath(user_path / json_filename, enforce="file")
+        elif user_path.is_file():
+            target_path = user_path
+        else:
+            _LOGGER.error(f"Invalid path: '{file_or_dir}'")
+            raise IOError()
+
+        with open(target_path, 'r') as f:
+            saved_data = json.load(f)
+
+        saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
+        config = saved_data[PytorchModelArchitectureKeys.CONFIG]
+
+        if saved_class_name != cls.__name__:
+            _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
+            raise ValueError()
+
+        # --- RECONSTRUCTION LOGIC ---
+        if 'schema_dict' not in config:
+            _LOGGER.error("Invalid architecture file: missing 'schema_dict'. This file may be from an older version.")
+            raise ValueError("Missing 'schema_dict' in config.")
+
+        schema_data = config.pop('schema_dict')
+
+        # Re-hydrate the categorical_index_map
+        # JSON saves all dict keys as strings, so we must convert them back to int.
+        raw_index_map = schema_data['categorical_index_map']
+        if raw_index_map is not None:
+            rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
+        else:
+            rehydrated_index_map = None
+
+        # Re-hydrate the FeatureSchema object
+        # JSON deserializes tuples as lists, so we must convert them back.
+        schema = FeatureSchema(
+            feature_names=tuple(schema_data['feature_names']),
+            continuous_feature_names=tuple(schema_data['continuous_feature_names']),
+            categorical_feature_names=tuple(schema_data['categorical_feature_names']),
+            categorical_index_map=rehydrated_index_map,
+            categorical_mappings=schema_data['categorical_mappings']
+        )
+
+        config['schema'] = schema
+        # --- End Reconstruction ---
+
+        model = cls(**config)
+        if verbose:
+            _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
+        return model
 
     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
         # Build the architecture string part-by-part
         parts = [
-            f"Tokenizer(features={self.
+            f"Tokenizer(features={len(self.schema.feature_names)}, dim={self.embedding_dim})",
             "[CLS]",
             f"TransformerEncoder(layers={self.num_layers}, heads={self.num_heads})",
             f"PredictionHead(outputs={self.out_targets})"
```
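The new `load` classmethod has to undo two lossy JSON conversions its comments call out: dict keys come back as strings, and tuples come back as lists. The round-trip below demonstrates both repairs on a toy config:

```python
import json

# Serialize and deserialize, as saving/loading the architecture file would.
config = {
    "schema_dict": {
        "feature_names": ("age", "color", "weight"),   # tuple going in
        "categorical_index_map": {1: 3},               # int keys going in
    }
}
restored = json.loads(json.dumps(config))

schema_data = restored["schema_dict"]
raw_index_map = schema_data["categorical_index_map"]
rehydrated_index_map = (
    {int(k): v for k, v in raw_index_map.items()} if raw_index_map is not None else None
)
feature_names = tuple(schema_data["feature_names"])

print(rehydrated_index_map)  # {1: 3}  -- int keys again
print(feature_names)         # ('age', 'color', 'weight')  -- a tuple again
```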
```diff
@@ -443,29 +507,41 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
 
 class _FeatureTokenizer(nn.Module):
     """
-    Transforms raw numerical and categorical features from any column order
+    Transforms raw numerical and categorical features from any column order
+    into a sequence of embeddings.
     """
     def __init__(self,
-
-                 categorical_map: Dict[int, int],
+                 schema: FeatureSchema,
                  embedding_dim: int):
         """
         Args:
-
-
-            embedding_dim (int):
+            schema (FeatureSchema):
+                The definitive schema object from data_exploration.
+            embedding_dim (int):
+                The dimension for all feature embeddings.
         """
         super().__init__()
 
-        #
-
-
+        # --- Get info from schema ---
+        categorical_map = schema.categorical_index_map
+
+        if categorical_map:
+            # Unpack the dictionary into separate lists
+            self.categorical_indices = list(categorical_map.keys())
+            cardinalities = list(categorical_map.values())
+        else:
+            self.categorical_indices = []
+            cardinalities = []
+
+        # Derive numerical indices by finding what's not categorical
+        all_indices = set(range(len(schema.feature_names)))
+        categorical_indices_set = set(self.categorical_indices)
+        self.numerical_indices = sorted(list(all_indices - categorical_indices_set))
 
-        self.numerical_indices = numerical_indices
         self.embedding_dim = embedding_dim
 
         # A learnable embedding for each numerical feature
-        self.numerical_embeddings = nn.Parameter(torch.randn(len(numerical_indices), embedding_dim))
+        self.numerical_embeddings = nn.Parameter(torch.randn(len(self.numerical_indices), embedding_dim))
 
         # A standard embedding layer for each categorical feature
         self.categorical_embeddings = nn.ModuleList(
```
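With the schema in hand, `_FeatureTokenizer` derives the numerical indices as a set difference and allocates one learnable vector per continuous feature plus one `nn.Embedding` table per categorical feature. A compressed sketch of that wiring for three features with column 1 categorical; note that scaling the numerical embedding by the raw feature value is an assumption about the forward pass, which this diff does not show:

```python
import torch
import torch.nn as nn

embedding_dim = 8
categorical_map = {1: 3}  # column 1 is categorical with cardinality 3
feature_count = 3

categorical_indices = list(categorical_map.keys())
numerical_indices = sorted(set(range(feature_count)) - set(categorical_indices))

# One learnable vector per continuous feature, one lookup table per categorical one.
numerical_embeddings = nn.Parameter(torch.randn(len(numerical_indices), embedding_dim))
categorical_embeddings = nn.ModuleList(
    [nn.Embedding(card, embedding_dim) for card in categorical_map.values()]
)

x = torch.tensor([[0.5, 2.0, -1.3]])              # batch of 1, raw column order
x_numerical = x[:, numerical_indices]              # (1, 2) continuous values
x_categorical = x[:, categorical_indices].long()   # (1, 1) category ids

# Assumed combination step: value-scaled embeddings for continuous features.
numerical_tokens = x_numerical.unsqueeze(-1) * numerical_embeddings   # (1, 2, 8)
categorical_tokens = categorical_embeddings[0](x_categorical)         # (1, 1, 8)
tokens = torch.cat([numerical_tokens, categorical_tokens], dim=1)     # (1, 3, 8)
print(tokens.shape)
```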
```diff
@@ -487,6 +563,8 @@ class _FeatureTokenizer(nn.Module):
         # Process categorical features
         categorical_tokens = []
         for i, embed_layer in enumerate(self.categorical_embeddings):
+            # x_categorical[:, i] selects the i-th categorical column
+            # (e.g., all values for the 'color' feature)
             token = embed_layer(x_categorical[:, i]).unsqueeze(1)
             categorical_tokens.append(token)
 
@@ -670,5 +748,7 @@ class SequencePredictorLSTM(nn.Module, _ArchitectureHandlerMixin):
         )
 
 
+# ---- PyTorch models ---
+
 def info():
     _script_info(__all__)
```