dragon-ml-toolbox 12.13.0__py3-none-any.whl → 14.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (35)
  1. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/METADATA +11 -2
  2. dragon_ml_toolbox-14.3.0.dist-info/RECORD +48 -0
  3. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/MICE_imputation.py +207 -5
  5. ml_tools/ML_callbacks.py +40 -8
  6. ml_tools/ML_datasetmaster.py +200 -261
  7. ml_tools/ML_evaluation.py +29 -17
  8. ml_tools/ML_evaluation_multi.py +13 -10
  9. ml_tools/ML_inference.py +14 -5
  10. ml_tools/ML_models.py +135 -55
  11. ml_tools/ML_models_advanced.py +323 -0
  12. ml_tools/ML_optimization.py +49 -36
  13. ml_tools/ML_trainer.py +560 -30
  14. ml_tools/ML_utilities.py +302 -4
  15. ml_tools/ML_vision_datasetmaster.py +1352 -0
  16. ml_tools/ML_vision_evaluation.py +260 -0
  17. ml_tools/ML_vision_inference.py +428 -0
  18. ml_tools/ML_vision_models.py +627 -0
  19. ml_tools/ML_vision_transformers.py +58 -0
  20. ml_tools/PSO_optimization.py +5 -1
  21. ml_tools/_ML_vision_recipe.py +88 -0
  22. ml_tools/__init__.py +1 -0
  23. ml_tools/_schema.py +96 -0
  24. ml_tools/custom_logger.py +37 -14
  25. ml_tools/data_exploration.py +576 -138
  26. ml_tools/keys.py +51 -1
  27. ml_tools/math_utilities.py +1 -1
  28. ml_tools/optimization_tools.py +65 -86
  29. ml_tools/serde.py +78 -17
  30. ml_tools/utilities.py +192 -3
  31. dragon_ml_toolbox-12.13.0.dist-info/RECORD +0 -41
  32. ml_tools/ML_simple_optimization.py +0 -413
  33. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/WHEEL +0 -0
  34. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE +0 -0
  35. {dragon_ml_toolbox-12.13.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py CHANGED
@@ -19,11 +19,12 @@ import torch
 import shap
 from pathlib import Path
 from typing import Union, Optional, List, Literal
+import warnings
 
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from .keys import SHAPKeys, PyTorchLogKeys
 
 
 __all__ = [
@@ -43,8 +44,8 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
         save_dir (str | Path): Directory to save the plot image.
     """
-    train_loss = history.get('train_loss', [])
-    val_loss = history.get('val_loss', [])
+    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
 
     if not train_loss and not val_loss:
         print("Warning: Loss history is empty or incomplete. Cannot plot.")
@@ -257,7 +258,7 @@ def shap_summary_plot(model,
                       feature_names: Optional[list[str]],
                       save_dir: Union[str, Path],
                       device: torch.device = torch.device('cpu'),
-                      explainer_type: Literal['deep', 'kernel'] = 'deep'):
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
     """
     Calculates SHAP values and saves summary plots and data.
 
@@ -269,7 +270,7 @@ def shap_summary_plot(model,
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
              PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
              slow and memory-intensive.
@@ -284,7 +285,7 @@ def shap_summary_plot(model,
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer (Preferred) ---
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -298,17 +299,19 @@ def shap_summary_plot(model,
 
         background_data = background_data.to(device)
         instances_to_explain = instances_to_explain.to(device)
-
-        explainer = shap.DeepExplainer(model, background_data)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
         # print("Calculating SHAP values with DeepExplainer...")
         shap_values = explainer.shap_values(instances_to_explain)
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "Using KernelExplainer. This is memory-intensive and slow. "
-            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
         # Ensure data is np.ndarray
@@ -344,14 +347,26 @@ def shap_summary_plot(model,
     else:
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1)
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
+    # Convert instances to a DataFrame. Robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
     # Save Bar Plot
     bar_path = save_dir_path / "shap_bar_plot.svg"
-    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     plt.title("SHAP Feature Importance")
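
The new squeeze guard above handles a shape quirk: for single-output regression models, SHAP can return values shaped (N, F, 1) while summary_plot expects (N, F). A minimal sketch of the same check in isolation (dummy data, not the library's code):

    import numpy as np

    shap_values = np.zeros((100, 12, 1))  # (samples, features, single output)

    # Same guard as in the diff: lists (multi-output results) are left untouched.
    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
        shap_values = shap_values.squeeze(-1)

    print(shap_values.shape)  # (100, 12)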
@@ -362,7 +377,7 @@ def shap_summary_plot(model,
 
     # Save Dot Plot
     dot_path = save_dir_path / "shap_dot_plot.svg"
-    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -385,9 +400,6 @@ def shap_summary_plot(model,
     mean_abs_shap = np.abs(shap_values).mean(axis=0)
 
     mean_abs_shap = mean_abs_shap.flatten()
-
-    if feature_names is None:
-        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
     summary_df = pd.DataFrame({
         SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -397,7 +409,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)
 
     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-        plt.ion()
+    plt.ion()
 
 
 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
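
Taken together, these hunks flip the default explainer to 'kernel' and route plotting through a pandas DataFrame so SHAP maps values to feature names itself. A hypothetical call after this change (argument names follow the diff body; the data values are illustrative):

    from ml_tools.ML_evaluation import shap_summary_plot

    shap_summary_plot(
        model=trained_model,               # a trained torch.nn.Module (assumed)
        background_data=X_train[:100],     # np.ndarray or torch.Tensor, per the type checks above
        instances_to_explain=X_test[:200],
        feature_names=None,                # generic 'feature_i' names are now auto-generated
        save_dir="results/shap",
        explainer_type="deep",             # opt back in to the faster DeepExplainer
    )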
ml_tools/ML_evaluation_multi.py CHANGED
@@ -20,6 +20,7 @@ from sklearn.metrics import (
 )
 from pathlib import Path
 from typing import Union, List, Literal
+import warnings
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
@@ -234,7 +235,7 @@ def multi_target_shap_summary_plot(
     target_names: List[str],
     save_dir: Union[str, Path],
     device: torch.device = torch.device('cpu'),
-    explainer_type: Literal['deep', 'kernel'] = 'deep'
+    explainer_type: Literal['deep', 'kernel'] = 'kernel'
 ):
     """
     Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -248,7 +249,7 @@ def multi_target_shap_summary_plot(
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
     _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -259,7 +260,7 @@ def multi_target_shap_summary_plot(
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer (Preferred) ---
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -273,18 +274,20 @@ def multi_target_shap_summary_plot(
 
         background_data = background_data.to(device)
         instances_to_explain = instances_to_explain.to(device)
-
-        explainer = shap.DeepExplainer(model, background_data)
-        print("Calculating SHAP values with DeepExplainer...")
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
+        # print("Calculating SHAP values with DeepExplainer...")
         # DeepExplainer returns a list of arrays for multi-output models
         shap_values_list = explainer.shap_values(instances_to_explain)
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "Using KernelExplainer. This is memory-intensive and slow. "
-            "Consider reducing 'n_samples' if the process terminates."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
        )
 
         # Convert all data to numpy
@@ -304,7 +307,7 @@ def multi_target_shap_summary_plot(
             return output.cpu().numpy()  # Return full multi-output array
 
         explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
-        print("Calculating SHAP values with KernelExplainer...")
+        # print("Calculating SHAP values with KernelExplainer...")
         # KernelExplainer also returns a list of arrays for multi-output models
         shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
         # instances_to_explain_np is already set
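
The kernel path wraps the torch model in a numpy-in/numpy-out function, since KernelExplainer is model-agnostic. A self-contained sketch of that pattern under the same names the diff uses (the linear model is a stand-in; shap.kmeans is the usual way to build the background summary):

    import numpy as np
    import shap
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 3)  # stand-in for a trained multi-target network
    background_data_np = np.random.rand(50, 4).astype(np.float32)
    instances_to_explain_np = np.random.rand(5, 4).astype(np.float32)

    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
        # KernelExplainer passes numpy arrays; run the model without gradients.
        with torch.no_grad():
            output = model(torch.as_tensor(x_np, dtype=torch.float32))
        return output.cpu().numpy()  # full multi-output array, shape (n, n_targets)

    background_summary = shap.kmeans(background_data_np, 10)
    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
    # One array of SHAP values per target for multi-output models.
    shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")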
ml_tools/ML_inference.py CHANGED
@@ -9,7 +9,7 @@ from .ML_scaler import PytorchScaler
 from ._script_info import _script_info
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
-from .keys import PyTorchInferenceKeys
+from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
 
 
 __all__ = [
@@ -56,11 +56,21 @@ class _BaseInferenceHandler(ABC):
         model_p = make_fullpath(state_dict, enforce="file")
 
         try:
-            # Load the state dictionary and apply it to the model structure
-            self.model.load_state_dict(torch.load(model_p, map_location=self.device))
+            # Load whatever is in the file
+            loaded_data = torch.load(model_p, map_location=self.device)
+
+            # Check if it's the new checkpoint dictionary or an old weights-only file
+            if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
+                # It's a new training checkpoint, extract the weights
+                self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
+            else:
+                # It's an old-style file (or just a state_dict), load it directly
+                self.model.load_state_dict(loaded_data)
+
+            _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
+
             self.model.to(self.device)
             self.model.eval()  # Set the model to evaluation mode
-            _LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
         except Exception as e:
             _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
             raise
@@ -72,7 +82,6 @@ class _BaseInferenceHandler(ABC):
             _LOGGER.warning("CUDA not available, switching to CPU.")
             device_lower = "cpu"
         elif device_lower == "mps" and not torch.backends.mps.is_available():
-            # Your M-series Mac will appreciate this check!
             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
             device_lower = "cpu"
         return torch.device(device_lower)
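
The net effect is that inference handlers now accept both full training checkpoints and bare state_dict files. A minimal sketch of the same branching outside the class, assuming the checkpoint stores weights under PyTorchCheckpointKeys.MODEL_STATE (the literal key string below is hypothetical):

    import torch
    import torch.nn as nn

    MODEL_STATE = "model_state_dict"  # hypothetical value of PyTorchCheckpointKeys.MODEL_STATE
    model = nn.Linear(8, 1)           # stand-in for the real architecture

    loaded_data = torch.load("checkpoint.pth", map_location="cpu")

    if isinstance(loaded_data, dict) and MODEL_STATE in loaded_data:
        # New-style training checkpoint: weights live under a dedicated key,
        # typically alongside optimizer state and epoch counters.
        model.load_state_dict(loaded_data[MODEL_STATE])
    else:
        # Old-style file: the loaded object is itself the state_dict.
        model.load_state_dict(loaded_data)

    model.eval()  # inference mode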
ml_tools/ML_models.py CHANGED
@@ -8,6 +8,7 @@ from ._logger import _LOGGER
 from .path_manager import make_fullpath
 from ._script_info import _script_info
 from .keys import PytorchModelArchitectureKeys
+from ._schema import FeatureSchema
 
 
 __all__ = [
@@ -298,76 +299,73 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
     """
     A Transformer-based model for tabular data tasks.
 
-    This model uses a Feature Tokenizer to convert all input features into a sequence of embeddings, prepends a [CLS] token, and processes the
+    This model uses a Feature Tokenizer to convert all input features into a
+    sequence of embeddings, prepends a [CLS] token, and processes the
     sequence with a standard Transformer Encoder.
     """
     def __init__(self, *,
-                 in_features: int,
+                 schema: FeatureSchema,
                  out_targets: int,
-                 categorical_index_map: Dict[int, int],
-                 embedding_dim: int = 32,
+                 embedding_dim: int = 256,
                  num_heads: int = 8,
                  num_layers: int = 6,
-                 dropout: float = 0.1):
+                 dropout: float = 0.2):
         """
         Args:
-            in_features (int): The total number of columns in the input data (features).
-            out_targets (int): Number of output targets (1 for regression).
-            categorical_index_map (Dict[int, int]): Maps categorical column index to its cardinality (number of unique categories).
-            embedding_dim (int): The dimension for all feature embeddings. Must be divisible by num_heads.
-            num_heads (int): The number of heads in the multi-head attention mechanism.
-            num_layers (int): The number of sub-encoder-layers in the transformer encoder.
-            dropout (float): The dropout value.
-
-        Note:
-            - All arguments are keyword-only to promote clarity.
-            - Column indices start at 0.
+            schema (FeatureSchema):
+                The definitive schema object created by `data_exploration.finalize_feature_schema()`.
+            out_targets (int):
+                Number of output targets (1 for regression).
+            embedding_dim (int):
+                The dimension for all feature embeddings. Must be divisible by num_heads. Common values: (64, 128, 192, 256, etc.)
+            num_heads (int):
+                The number of heads in the multi-head attention mechanism. Common values: (4, 8, 16)
+            num_layers (int):
+                The number of sub-encoder-layers in the transformer encoder. Common values: (4, 8, 12)
+            dropout (float):
+                The dropout value.
+
+        ## Note:
 
-        ### Data Preparation
-        The model requires a specific input format. All columns in the input DataFrame must be numerical, but they are treated differently based on the
-        provided index lists.
-
-        **Nominal Categorical Features** (e.g., 'City', 'Color'): Should **NOT** be one-hot encoded.
-        Instead, convert them to integer codes (label encoding). You must then provide a dictionary mapping their column indices to
-        their cardinality (the number of unique categories) via the `categorical_map` parameter.
-
-        **Ordinal & Binary Features** (e.g., 'Low/Medium/High', 'True/False'): Should be treated as **numerical**. Map them to numbers that
-        represent their state (e.g., `{'Low': 0, 'Medium': 1}` or `{False: 0, True: 1}`). Their column indices should **NOT** be included in the
-        `categorical_map` parameter.
+        **Embedding Dimension:** "Width" of the model. It's the N-dimension vector that will be used to represent each one of the features.
+        - Each continuous feature gets its own learnable N-dimension vector.
+        - Each categorical feature gets an embedding table that maps every category (e.g., "color=red", "color=blue") to a unique N-dimension vector.
+
+        **Attention Heads:** Controls the "Multi-Head Attention" mechanism. Instead of looking at all the feature interactions at once, the model splits its attention into N parallel heads.
+        - Embedding Dimensions get divided by the number of Attention Heads, resulting in the dimensions assigned per head.
 
-        **Standard Numerical and Continuous Features** (e.g., 'Age', 'Price'): It is highly recommended to scale them before training.
+        **Number of Layers:** "Depth" of the model. Number of identical `TransformerEncoderLayer` blocks that are stacked on top of each other.
+        - Layer 1: The attention heads find simple, direct interactions between the features.
+        - Layer 2: Takes the output of Layer 1 and finds interactions between those interactions and so on.
+        - Trade-off: More layers are more powerful but are slower to train and more prone to overfitting. If the training loss goes down but the validation loss goes up, you might have too many layers (or need more dropout).
+
         """
         super().__init__()
 
+        # --- Get info from schema ---
+        in_features = len(schema.feature_names)
+        categorical_index_map = schema.categorical_index_map
+
         # --- Validation ---
-        if categorical_index_map and max(categorical_index_map.keys()) >= in_features:
+        if categorical_index_map and (max(categorical_index_map.keys()) >= in_features):
             _LOGGER.error(f"A categorical index ({max(categorical_index_map.keys())}) is out of bounds for the provided input features ({in_features}).")
             raise ValueError()
 
-        # --- Derive numerical indices ---
-        all_indices = set(range(in_features))
-        categorical_indices_set = set(categorical_index_map.keys())
-        numerical_indices = sorted(list(all_indices - categorical_indices_set))
-
         # --- Save configuration ---
-        self.in_features = in_features
+        self.schema = schema  # <-- Save the whole schema
         self.out_targets = out_targets
-        self.numerical_indices = numerical_indices
-        self.categorical_map = categorical_index_map
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.num_layers = num_layers
         self.dropout = dropout
 
-        # --- 1. Feature Tokenizer ---
+        # --- 1. Feature Tokenizer (now takes the schema) ---
         self.tokenizer = _FeatureTokenizer(
-            numerical_indices=numerical_indices,
-            categorical_map=categorical_index_map,
+            schema=schema,
             embedding_dim=embedding_dim
         )
 
         # --- 2. CLS Token ---
-        # A learnable token that will be prepended to the sequence.
         self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
 
         # --- 3. Transformer Encoder ---
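
As the rewritten docstring notes, the embedding dimension is split evenly across attention heads, so it must be divisible by num_heads. A one-line check with the new defaults:

    embedding_dim, num_heads = 256, 8  # new defaults in this version

    assert embedding_dim % num_heads == 0, "embedding_dim must be divisible by num_heads"
    head_dim = embedding_dim // num_heads
    print(head_dim)  # 32 dimensions of each token go to every attention head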
@@ -416,21 +414,87 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
 
     def get_architecture_config(self) -> Dict[str, Any]:
         """Returns the full configuration of the model."""
+        # Deconstruct schema into a JSON-friendly dict
+        # Tuples are saved as lists
+        schema_dict = {
+            'feature_names': self.schema.feature_names,
+            'continuous_feature_names': self.schema.continuous_feature_names,
+            'categorical_feature_names': self.schema.categorical_feature_names,
+            'categorical_index_map': self.schema.categorical_index_map,
+            'categorical_mappings': self.schema.categorical_mappings
+        }
+
         return {
-            'in_features': self.in_features,
+            'schema_dict': schema_dict,
             'out_targets': self.out_targets,
-            'categorical_map': self.categorical_map,
             'embedding_dim': self.embedding_dim,
             'num_heads': self.num_heads,
             'num_layers': self.num_layers,
             'dropout': self.dropout
         }
+
+    @classmethod
+    def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
+        """Loads a model architecture from a JSON file."""
+        user_path = make_fullpath(file_or_dir)
+
+        if user_path.is_dir():
+            json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+            target_path = make_fullpath(user_path / json_filename, enforce="file")
+        elif user_path.is_file():
+            target_path = user_path
+        else:
+            _LOGGER.error(f"Invalid path: '{file_or_dir}'")
+            raise IOError()
+
+        with open(target_path, 'r') as f:
+            saved_data = json.load(f)
+
+        saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
+        config = saved_data[PytorchModelArchitectureKeys.CONFIG]
+
+        if saved_class_name != cls.__name__:
+            _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
+            raise ValueError()
+
+        # --- RECONSTRUCTION LOGIC ---
+        if 'schema_dict' not in config:
+            _LOGGER.error("Invalid architecture file: missing 'schema_dict'. This file may be from an older version.")
+            raise ValueError("Missing 'schema_dict' in config.")
+
+        schema_data = config.pop('schema_dict')
+
+        # Re-hydrate the categorical_index_map
+        # JSON saves all dict keys as strings, so we must convert them back to int.
+        raw_index_map = schema_data['categorical_index_map']
+        if raw_index_map is not None:
+            rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
+        else:
+            rehydrated_index_map = None
+
+        # Re-hydrate the FeatureSchema object
+        # JSON deserializes tuples as lists, so we must convert them back.
+        schema = FeatureSchema(
+            feature_names=tuple(schema_data['feature_names']),
+            continuous_feature_names=tuple(schema_data['continuous_feature_names']),
+            categorical_feature_names=tuple(schema_data['categorical_feature_names']),
+            categorical_index_map=rehydrated_index_map,
+            categorical_mappings=schema_data['categorical_mappings']
+        )
+
+        config['schema'] = schema
+        # --- End Reconstruction ---
+
+        model = cls(**config)
+        if verbose:
+            _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
+        return model
 
     def __repr__(self) -> str:
         """Returns the developer-friendly string representation of the model."""
         # Build the architecture string part-by-part
         parts = [
-            f"Tokenizer(features={self.in_features}, dim={self.embedding_dim})",
+            f"Tokenizer(features={len(self.schema.feature_names)}, dim={self.embedding_dim})",
             "[CLS]",
             f"TransformerEncoder(layers={self.num_layers}, heads={self.num_heads})",
             f"PredictionHead(outputs={self.out_targets})"
@@ -443,29 +507,41 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
 
 class _FeatureTokenizer(nn.Module):
     """
-    Transforms raw numerical and categorical features from any column order into a sequence of embeddings.
+    Transforms raw numerical and categorical features from any column order
+    into a sequence of embeddings.
     """
     def __init__(self,
-                 numerical_indices: List[int],
-                 categorical_map: Dict[int, int],
+                 schema: FeatureSchema,
                  embedding_dim: int):
         """
         Args:
-            numerical_indices (List[int]): A list of column indices for the numerical features.
-            categorical_map (Dict[int, int]): A dictionary mapping each categorical column index to its cardinality (number of unique categories).
-            embedding_dim (int): The dimension for all feature embeddings.
+            schema (FeatureSchema):
+                The definitive schema object from data_exploration.
+            embedding_dim (int):
+                The dimension for all feature embeddings.
         """
         super().__init__()
 
-        # Unpack the dictionary into separate lists for indices and cardinalities
-        self.categorical_indices = list(categorical_map.keys())
-        cardinalities = list(categorical_map.values())
+        # --- Get info from schema ---
+        categorical_map = schema.categorical_index_map
+
+        if categorical_map:
+            # Unpack the dictionary into separate lists
+            self.categorical_indices = list(categorical_map.keys())
+            cardinalities = list(categorical_map.values())
+        else:
+            self.categorical_indices = []
+            cardinalities = []
+
+        # Derive numerical indices by finding what's not categorical
+        all_indices = set(range(len(schema.feature_names)))
+        categorical_indices_set = set(self.categorical_indices)
+        self.numerical_indices = sorted(list(all_indices - categorical_indices_set))
 
-        self.numerical_indices = numerical_indices
         self.embedding_dim = embedding_dim
 
         # A learnable embedding for each numerical feature
-        self.numerical_embeddings = nn.Parameter(torch.randn(len(numerical_indices), embedding_dim))
+        self.numerical_embeddings = nn.Parameter(torch.randn(len(self.numerical_indices), embedding_dim))
 
         # A standard embedding layer for each categorical feature
         self.categorical_embeddings = nn.ModuleList(
@@ -487,6 +563,8 @@ class _FeatureTokenizer(nn.Module):
         # Process categorical features
         categorical_tokens = []
         for i, embed_layer in enumerate(self.categorical_embeddings):
+            # x_categorical[:, i] selects the i-th categorical column
+            # (e.g., all values for the 'color' feature)
             token = embed_layer(x_categorical[:, i]).unsqueeze(1)
             categorical_tokens.append(token)
 
@@ -670,5 +748,7 @@ class SequencePredictorLSTM(nn.Module, _ArchitectureHandlerMixin):
         )
 
 
+# ---- PyTorch models ---
+
 def info():
     _script_info(__all__)
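
An end-to-end sketch of the new schema-driven constructor. The FeatureSchema values below are invented for illustration, the layout of categorical_mappings is assumed from the field names in this diff, and per the docstring the schema would normally come from data_exploration.finalize_feature_schema():

    from ml_tools._schema import FeatureSchema
    from ml_tools.ML_models import TabularTransformer

    schema = FeatureSchema(
        feature_names=("age", "income", "city"),
        continuous_feature_names=("age", "income"),
        categorical_feature_names=("city",),
        categorical_index_map={2: 3},  # column 2 is categorical with 3 categories
        categorical_mappings={"city": {"NY": 0, "LA": 1, "SF": 2}},  # assumed layout
    )

    # All constructor arguments are keyword-only; defaults are embedding_dim=256,
    # num_heads=8, num_layers=6, dropout=0.2.
    model = TabularTransformer(schema=schema, out_targets=1)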