dragon-ml-toolbox 10.1.1__py3-none-any.whl → 14.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic.

Files changed (48)
  1. {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/METADATA +38 -63
  2. dragon_ml_toolbox-14.2.0.dist-info/RECORD +48 -0
  3. {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE +1 -1
  4. {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +11 -0
  5. ml_tools/ETL_cleaning.py +175 -59
  6. ml_tools/ETL_engineering.py +506 -70
  7. ml_tools/GUI_tools.py +2 -1
  8. ml_tools/MICE_imputation.py +212 -7
  9. ml_tools/ML_callbacks.py +73 -40
  10. ml_tools/ML_datasetmaster.py +267 -284
  11. ml_tools/ML_evaluation.py +119 -58
  12. ml_tools/ML_evaluation_multi.py +107 -32
  13. ml_tools/ML_inference.py +15 -5
  14. ml_tools/ML_models.py +234 -170
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +321 -97
  17. ml_tools/ML_scaler.py +10 -5
  18. ml_tools/ML_trainer.py +585 -40
  19. ml_tools/ML_utilities.py +528 -0
  20. ml_tools/ML_vision_datasetmaster.py +1315 -0
  21. ml_tools/ML_vision_evaluation.py +260 -0
  22. ml_tools/ML_vision_inference.py +428 -0
  23. ml_tools/ML_vision_models.py +627 -0
  24. ml_tools/ML_vision_transformers.py +58 -0
  25. ml_tools/PSO_optimization.py +10 -7
  26. ml_tools/RNN_forecast.py +2 -0
  27. ml_tools/SQL.py +22 -9
  28. ml_tools/VIF_factor.py +4 -3
  29. ml_tools/_ML_vision_recipe.py +88 -0
  30. ml_tools/__init__.py +1 -0
  31. ml_tools/_logger.py +0 -2
  32. ml_tools/_schema.py +96 -0
  33. ml_tools/constants.py +79 -0
  34. ml_tools/custom_logger.py +164 -16
  35. ml_tools/data_exploration.py +1092 -109
  36. ml_tools/ensemble_evaluation.py +48 -1
  37. ml_tools/ensemble_inference.py +6 -7
  38. ml_tools/ensemble_learning.py +4 -3
  39. ml_tools/handle_excel.py +1 -0
  40. ml_tools/keys.py +80 -0
  41. ml_tools/math_utilities.py +259 -0
  42. ml_tools/optimization_tools.py +198 -24
  43. ml_tools/path_manager.py +144 -45
  44. ml_tools/serde.py +192 -0
  45. ml_tools/utilities.py +287 -227
  46. dragon_ml_toolbox-10.1.1.dist-info/RECORD +0 -36
  47. {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/WHEEL +0 -0
  48. {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_models.py CHANGED
@@ -3,9 +3,13 @@ from torch import nn
3
3
  from typing import List, Union, Tuple, Dict, Any
4
4
  from pathlib import Path
5
5
  import json
6
+
6
7
  from ._logger import _LOGGER
7
8
  from .path_manager import make_fullpath
8
9
  from ._script_info import _script_info
10
+ from .keys import PytorchModelArchitectureKeys
11
+ from ._schema import FeatureSchema
12
+
9
13
 
10
14
  __all__ = [
11
15
  "MultilayerPerceptron",
@@ -13,12 +17,67 @@ __all__ = [
13
17
  "MultiHeadAttentionMLP",
14
18
  "TabularTransformer",
15
19
  "SequencePredictorLSTM",
16
- "save_architecture",
17
- "load_architecture"
18
20
  ]
19
21
 
20
22
 
21
- class _BaseMLP(nn.Module):
23
+ class _ArchitectureHandlerMixin:
24
+ """
25
+ A mixin class to provide save and load functionality for model architectures.
26
+ """
27
+ def save(self: nn.Module, directory: Union[str, Path], verbose: bool = True): # type: ignore
28
+ """Saves the model's architecture to a JSON file."""
29
+ if not hasattr(self, 'get_architecture_config'):
30
+ _LOGGER.error(f"Model '{self.__class__.__name__}' must have a 'get_architecture_config()' method to use this functionality.")
31
+ raise AttributeError()
32
+
33
+ path_dir = make_fullpath(directory, make=True, enforce="directory")
34
+
35
+ json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
36
+
37
+ full_path = path_dir / json_filename
38
+
39
+ config = {
40
+ PytorchModelArchitectureKeys.MODEL: self.__class__.__name__,
41
+ PytorchModelArchitectureKeys.CONFIG: self.get_architecture_config() # type: ignore
42
+ }
43
+
44
+ with open(full_path, 'w') as f:
45
+ json.dump(config, f, indent=4)
46
+
47
+ if verbose:
48
+ _LOGGER.info(f"Architecture for '{self.__class__.__name__}' saved as '{full_path.name}'")
49
+
50
+ @classmethod
51
+ def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
52
+ """Loads a model architecture from a JSON file. If a directory is provided, the function will attempt to load a JSON file inside."""
53
+ user_path = make_fullpath(file_or_dir)
54
+
55
+ if user_path.is_dir():
56
+ json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
57
+ target_path = make_fullpath(user_path / json_filename, enforce="file")
58
+ elif user_path.is_file():
59
+ target_path = user_path
60
+ else:
61
+ _LOGGER.error(f"Invalid path: '{file_or_dir}'")
62
+ raise IOError()
63
+
64
+ with open(target_path, 'r') as f:
65
+ saved_data = json.load(f)
66
+
67
+ saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
68
+ config = saved_data[PytorchModelArchitectureKeys.CONFIG]
69
+
70
+ if saved_class_name != cls.__name__:
71
+ _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
72
+ raise ValueError()
73
+
74
+ model = cls(**config)
75
+ if verbose:
76
+ _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
77
+ return model
78
+
79
+
80
+ class _BaseMLP(nn.Module, _ArchitectureHandlerMixin):
22
81
  """
23
82
  A base class for Multilayer Perceptrons.
24
83
 
@@ -68,7 +127,7 @@ class _BaseMLP(nn.Module):
68
127
  # Set a customizable Prediction Head for flexibility, specially in transfer learning and fine-tuning
69
128
  self.output_layer = nn.Linear(current_features, out_targets)
70
129
 
71
- def get_config(self) -> Dict[str, Any]:
130
+ def get_architecture_config(self) -> Dict[str, Any]:
72
131
  """Returns the base configuration of the model."""
73
132
  return {
74
133
  'in_features': self.in_features,
@@ -90,6 +149,31 @@ class _BaseMLP(nn.Module):
90
149
  return f"{name}(arch: {arch_str})"
91
150
 
92
151
 
152
+ class _BaseAttention(_BaseMLP):
153
+ """
154
+ Abstract base class for MLP models that incorporate an attention mechanism
155
+ before the main MLP layers.
156
+ """
157
+ def __init__(self, *args, **kwargs):
158
+ super().__init__(*args, **kwargs)
159
+ # By default, models inheriting this do not have the flag.
160
+ self.attention = None
161
+ self.has_interpretable_attention = False
162
+
163
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
164
+ """Defines the standard forward pass."""
165
+ logits, _attention_weights = self.forward_attention(x)
166
+ return logits
167
+
168
+ def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
169
+ """Returns logits and attention weights."""
170
+ # This logic is now shared and defined in one place
171
+ x, attention_weights = self.attention(x) # type: ignore
172
+ x = self.mlp(x)
173
+ logits = self.output_layer(x)
174
+ return logits, attention_weights
175
+
176
+
93
177
  class MultilayerPerceptron(_BaseMLP):
94
178
  """
95
179
  Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
@@ -127,7 +211,7 @@ class MultilayerPerceptron(_BaseMLP):
127
211
  return self._repr_helper(name="MultilayerPerceptron", mlp_layers=layer_sizes)
128
212
 
129
213
 
130
- class AttentionMLP(_BaseMLP):
214
+ class AttentionMLP(_BaseAttention):
131
215
  """
132
216
  A Multilayer Perceptron (MLP) that incorporates an Attention layer to dynamically weigh input features.
133
217
 
@@ -148,25 +232,7 @@ class AttentionMLP(_BaseMLP):
148
232
  super().__init__(in_features, out_targets, hidden_layers, drop_out)
149
233
  # Attention
150
234
  self.attention = _AttentionLayer(in_features)
151
-
152
- def forward(self, x: torch.Tensor) -> torch.Tensor:
153
- """
154
- Defines the standard forward pass.
155
- """
156
- logits, _attention_weights = self.forward_attention(x)
157
- return logits
158
-
159
- def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
160
- """
161
- Returns logits and attention weights
162
- """
163
- # The attention layer returns the processed x and the weights
164
- x, attention_weights = self.attention(x)
165
-
166
- # Pass the attention-modified tensor through the MLP
167
- logits = self.mlp(x)
168
-
169
- return logits, attention_weights
235
+ self.has_interpretable_attention = True
170
236
 
171
237
  def __repr__(self) -> str:
172
238
  """Returns the developer-friendly string representation of the model."""
@@ -181,7 +247,7 @@ class AttentionMLP(_BaseMLP):
181
247
  return self._repr_helper(name="AttentionMLP", mlp_layers=arch)
182
248
 
183
249
 
184
- class MultiHeadAttentionMLP(_BaseMLP):
250
+ class MultiHeadAttentionMLP(_BaseAttention):
185
251
  """
186
252
  An MLP that incorporates a standard `nn.MultiheadAttention` layer to process
187
253
  the input features.
@@ -210,27 +276,9 @@ class MultiHeadAttentionMLP(_BaseMLP):
210
276
  dropout=attention_dropout
211
277
  )
212
278
 
213
- def forward(self, x: torch.Tensor) -> torch.Tensor:
214
- """Defines the standard forward pass of the model."""
215
- logits, _attention_weights = self.forward_attention(x)
216
- return logits
217
-
218
- def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
219
- """
220
- Returns logits and attention weights.
221
- """
222
- # The attention layer returns the processed x and the weights
223
- x, attention_weights = self.attention(x)
224
-
225
- # Pass the attention-modified tensor through the MLP and prediction head
226
- x = self.mlp(x)
227
- logits = self.output_layer(x)
228
-
229
- return logits, attention_weights
230
-
231
- def get_config(self) -> Dict[str, Any]:
279
+ def get_architecture_config(self) -> Dict[str, Any]:
232
280
  """Returns the full configuration of the model."""
233
- config = super().get_config()
281
+ config = super().get_architecture_config()
234
282
  config['num_heads'] = self.num_heads
235
283
  config['attention_dropout'] = self.attention_dropout
236
284
  return config
@@ -247,70 +295,77 @@ class MultiHeadAttentionMLP(_BaseMLP):
247
295
  return f"MultiHeadAttentionMLP(arch: {arch_str})"
248
296
 
249
297
 
250
- class TabularTransformer(nn.Module):
298
+ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
251
299
  """
252
300
  A Transformer-based model for tabular data tasks.
253
301
 
254
- This model uses a Feature Tokenizer to convert all input features into a sequence of embeddings, prepends a [CLS] token, and processes the
302
+ This model uses a Feature Tokenizer to convert all input features into a
303
+ sequence of embeddings, prepends a [CLS] token, and processes the
255
304
  sequence with a standard Transformer Encoder.
256
305
  """
257
306
  def __init__(self, *,
307
+ schema: FeatureSchema,
258
308
  out_targets: int,
259
- numerical_indices: List[int],
260
- categorical_map: Dict[int, int],
261
- embedding_dim: int = 32,
309
+ embedding_dim: int = 256,
262
310
  num_heads: int = 8,
263
311
  num_layers: int = 6,
264
- dropout: float = 0.1):
312
+ dropout: float = 0.2):
265
313
  """
266
314
  Args:
267
- out_targets (int): Number of output targets (1 for regression).
268
- numerical_indices (List[int]): Column indices for numerical features.
269
- categorical_map (Dict[int, int]): Maps categorical column index to its cardinality (number of unique categories).
270
- embedding_dim (int): The dimension for all feature embeddings. Must be divisible by num_heads.
271
- num_heads (int): The number of heads in the multi-head attention mechanism.
272
- num_layers (int): The number of sub-encoder-layers in the transformer encoder.
273
- dropout (float): The dropout value.
274
-
275
- Note:
276
- - All arguments are keyword-only to promote clarity.
277
- - Column indices start at 0.
315
+ schema (FeatureSchema):
316
+ The definitive schema object created by `data_exploration.finalize_feature_schema()`.
317
+ out_targets (int):
318
+ Number of output targets (1 for regression).
319
+ embedding_dim (int):
320
+ The dimension for all feature embeddings. Must be divisible by num_heads. Common values: (64, 128, 192, 256, etc.)
321
+ num_heads (int):
322
+ The number of heads in the multi-head attention mechanism. Common values: (4, 8, 16)
323
+ num_layers (int):
324
+ The number of sub-encoder-layers in the transformer encoder. Common values: (4, 8, 12)
325
+ dropout (float):
326
+ The dropout value.
327
+
328
+ ## Note:
278
329
 
279
- ### Data Preparation
280
- The model requires a specific input format. All columns in the input DataFrame must be numerical, but they are treated differently based on the
281
- provided index lists.
282
-
283
- **Nominal Categorical Features** (e.g., 'City', 'Color'): Should **NOT** be one-hot encoded.
284
- Instead, convert them to integer codes (label encoding). You must then provide a dictionary mapping their column indices to
285
- their cardinality (the number of unique categories) via the `categorical_map` parameter.
286
-
287
- **Ordinal & Binary Features** (e.g., 'Low/Medium/High', 'True/False'): Should be treated as **numerical**. Map them to numbers that
288
- represent their state (e.g., `{'Low': 0, 'Medium': 1}` or `{False: 0, True: 1}`). Their column indices should be included in the
289
- `numerical_indices` list.
330
+ **Embedding Dimension:** "Width" of the model. It's the N-dimension vector that will be used to represent each one of the features.
331
+ - Each continuous feature gets its own learnable N-dimension vector.
332
+ - Each categorical feature gets an embedding table that maps every category (e.g., "color=red", "color=blue") to a unique N-dimension vector.
333
+
334
+ **Attention Heads:** Controls the "Multi-Head Attention" mechanism. Instead of looking at all the feature interactions at once, the model splits its attention into N parallel heads.
335
+ - Embedding Dimensions get divided by the number of Attention Heads, resulting in the dimensions assigned per head.
290
336
 
291
- **Standard Numerical Features** (e.g., 'Age', 'Price'): Should be included in the `numerical_indices` list. It is highly recommended to
292
- scale them before training.
337
+ **Number of Layers:** "Depth" of the model. Number of identical `TransformerEncoderLayer` blocks that are stacked on top of each other.
338
+ - Layer 1: The attention heads find simple, direct interactions between the features.
339
+ - Layer 2: Takes the output of Layer 1 and finds interactions between those interactions and so on.
340
+ - Trade-off: More layers are more powerful but are slower to train and more prone to overfitting. If the training loss goes down but the validation loss goes up, you might have too many layers (or need more dropout).
341
+
293
342
  """
294
343
  super().__init__()
344
+
345
+ # --- Get info from schema ---
346
+ in_features = len(schema.feature_names)
347
+ categorical_index_map = schema.categorical_index_map
295
348
 
349
+ # --- Validation ---
350
+ if categorical_index_map and (max(categorical_index_map.keys()) >= in_features):
351
+ _LOGGER.error(f"A categorical index ({max(categorical_index_map.keys())}) is out of bounds for the provided input features ({in_features}).")
352
+ raise ValueError()
353
+
296
354
  # --- Save configuration ---
355
+ self.schema = schema # <-- Save the whole schema
297
356
  self.out_targets = out_targets
298
- self.numerical_indices = numerical_indices
299
- self.categorical_map = categorical_map
300
357
  self.embedding_dim = embedding_dim
301
358
  self.num_heads = num_heads
302
359
  self.num_layers = num_layers
303
360
  self.dropout = dropout
304
361
 
305
- # --- 1. Feature Tokenizer ---
362
+ # --- 1. Feature Tokenizer (now takes the schema) ---
306
363
  self.tokenizer = _FeatureTokenizer(
307
- numerical_indices=numerical_indices,
308
- categorical_map=categorical_map,
364
+ schema=schema,
309
365
  embedding_dim=embedding_dim
310
366
  )
311
367
 
312
368
  # --- 2. CLS Token ---
313
- # A learnable token that will be prepended to the sequence.
314
369
  self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
315
370
 
316
371
  # --- 3. Transformer Encoder ---
@@ -357,25 +412,89 @@ class TabularTransformer(nn.Module):
357
412
 
358
413
  return logits
359
414
 
360
- def get_config(self) -> Dict[str, Any]:
415
+ def get_architecture_config(self) -> Dict[str, Any]:
361
416
  """Returns the full configuration of the model."""
417
+ # Deconstruct schema into a JSON-friendly dict
418
+ # Tuples are saved as lists
419
+ schema_dict = {
420
+ 'feature_names': self.schema.feature_names,
421
+ 'continuous_feature_names': self.schema.continuous_feature_names,
422
+ 'categorical_feature_names': self.schema.categorical_feature_names,
423
+ 'categorical_index_map': self.schema.categorical_index_map,
424
+ 'categorical_mappings': self.schema.categorical_mappings
425
+ }
426
+
362
427
  return {
428
+ 'schema_dict': schema_dict,
363
429
  'out_targets': self.out_targets,
364
- 'numerical_indices': self.numerical_indices,
365
- 'categorical_map': self.categorical_map,
366
430
  'embedding_dim': self.embedding_dim,
367
431
  'num_heads': self.num_heads,
368
432
  'num_layers': self.num_layers,
369
433
  'dropout': self.dropout
370
434
  }
435
+
436
+ @classmethod
437
+ def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
438
+ """Loads a model architecture from a JSON file."""
439
+ user_path = make_fullpath(file_or_dir)
440
+
441
+ if user_path.is_dir():
442
+ json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
443
+ target_path = make_fullpath(user_path / json_filename, enforce="file")
444
+ elif user_path.is_file():
445
+ target_path = user_path
446
+ else:
447
+ _LOGGER.error(f"Invalid path: '{file_or_dir}'")
448
+ raise IOError()
449
+
450
+ with open(target_path, 'r') as f:
451
+ saved_data = json.load(f)
452
+
453
+ saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
454
+ config = saved_data[PytorchModelArchitectureKeys.CONFIG]
455
+
456
+ if saved_class_name != cls.__name__:
457
+ _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
458
+ raise ValueError()
459
+
460
+ # --- RECONSTRUCTION LOGIC ---
461
+ if 'schema_dict' not in config:
462
+ _LOGGER.error("Invalid architecture file: missing 'schema_dict'. This file may be from an older version.")
463
+ raise ValueError("Missing 'schema_dict' in config.")
464
+
465
+ schema_data = config.pop('schema_dict')
466
+
467
+ # Re-hydrate the categorical_index_map
468
+ # JSON saves all dict keys as strings, so we must convert them back to int.
469
+ raw_index_map = schema_data['categorical_index_map']
470
+ if raw_index_map is not None:
471
+ rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
472
+ else:
473
+ rehydrated_index_map = None
474
+
475
+ # Re-hydrate the FeatureSchema object
476
+ # JSON deserializes tuples as lists, so we must convert them back.
477
+ schema = FeatureSchema(
478
+ feature_names=tuple(schema_data['feature_names']),
479
+ continuous_feature_names=tuple(schema_data['continuous_feature_names']),
480
+ categorical_feature_names=tuple(schema_data['categorical_feature_names']),
481
+ categorical_index_map=rehydrated_index_map,
482
+ categorical_mappings=schema_data['categorical_mappings']
483
+ )
484
+
485
+ config['schema'] = schema
486
+ # --- End Reconstruction ---
487
+
488
+ model = cls(**config)
489
+ if verbose:
490
+ _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
491
+ return model
371
492
 
372
493
  def __repr__(self) -> str:
373
494
  """Returns the developer-friendly string representation of the model."""
374
- num_features = len(self.numerical_indices) + len(self.categorical_map)
375
-
376
495
  # Build the architecture string part-by-part
377
496
  parts = [
378
- f"Tokenizer(features={num_features}, dim={self.embedding_dim})",
497
+ f"Tokenizer(features={len(self.schema.feature_names)}, dim={self.embedding_dim})",
379
498
  "[CLS]",
380
499
  f"TransformerEncoder(layers={self.num_layers}, heads={self.num_heads})",
381
500
  f"PredictionHead(outputs={self.out_targets})"
@@ -388,29 +507,41 @@ class TabularTransformer(nn.Module):
388
507
 
389
508
  class _FeatureTokenizer(nn.Module):
390
509
  """
391
- Transforms raw numerical and categorical features from any column order into a sequence of embeddings.
510
+ Transforms raw numerical and categorical features from any column order
511
+ into a sequence of embeddings.
392
512
  """
393
513
  def __init__(self,
394
- numerical_indices: List[int],
395
- categorical_map: Dict[int, int],
514
+ schema: FeatureSchema,
396
515
  embedding_dim: int):
397
516
  """
398
517
  Args:
399
- numerical_indices (List[int]): A list of column indices for the numerical features.
400
- categorical_map (Dict[int, int]): A dictionary mapping each categorical column index to its cardinality (number of unique categories).
401
- embedding_dim (int): The dimension for all feature embeddings.
518
+ schema (FeatureSchema):
519
+ The definitive schema object from data_exploration.
520
+ embedding_dim (int):
521
+ The dimension for all feature embeddings.
402
522
  """
403
523
  super().__init__()
404
524
 
405
- # Unpack the dictionary into separate lists for indices and cardinalities
406
- self.categorical_indices = list(categorical_map.keys())
407
- cardinalities = list(categorical_map.values())
525
+ # --- Get info from schema ---
526
+ categorical_map = schema.categorical_index_map
527
+
528
+ if categorical_map:
529
+ # Unpack the dictionary into separate lists
530
+ self.categorical_indices = list(categorical_map.keys())
531
+ cardinalities = list(categorical_map.values())
532
+ else:
533
+ self.categorical_indices = []
534
+ cardinalities = []
535
+
536
+ # Derive numerical indices by finding what's not categorical
537
+ all_indices = set(range(len(schema.feature_names)))
538
+ categorical_indices_set = set(self.categorical_indices)
539
+ self.numerical_indices = sorted(list(all_indices - categorical_indices_set))
408
540
 
409
- self.numerical_indices = numerical_indices
410
541
  self.embedding_dim = embedding_dim
411
542
 
412
543
  # A learnable embedding for each numerical feature
413
- self.numerical_embeddings = nn.Parameter(torch.randn(len(numerical_indices), embedding_dim))
544
+ self.numerical_embeddings = nn.Parameter(torch.randn(len(self.numerical_indices), embedding_dim))
414
545
 
415
546
  # A standard embedding layer for each categorical feature
416
547
  self.categorical_embeddings = nn.ModuleList(
@@ -432,6 +563,8 @@ class _FeatureTokenizer(nn.Module):
432
563
  # Process categorical features
433
564
  categorical_tokens = []
434
565
  for i, embed_layer in enumerate(self.categorical_embeddings):
566
+ # x_categorical[:, i] selects the i-th categorical column
567
+ # (e.g., all values for the 'color' feature)
435
568
  token = embed_layer(x_categorical[:, i]).unsqueeze(1)
436
569
  categorical_tokens.append(token)
437
570
 
@@ -529,7 +662,7 @@ class _MultiHeadAttentionLayer(nn.Module):
529
662
  return out, attn_weights.squeeze()
530
663
 
531
664
 
532
- class SequencePredictorLSTM(nn.Module):
665
+ class SequencePredictorLSTM(nn.Module, _ArchitectureHandlerMixin):
533
666
  """
534
667
  A simple LSTM-based network for sequence-to-sequence prediction tasks.
535
668
 
@@ -597,7 +730,7 @@ class SequencePredictorLSTM(nn.Module):
597
730
 
598
731
  return predictions
599
732
 
600
- def get_config(self) -> dict:
733
+ def get_architecture_config(self) -> dict:
601
734
  """Returns the configuration of the model."""
602
735
  return {
603
736
  'features': self.features,
@@ -615,76 +748,7 @@ class SequencePredictorLSTM(nn.Module):
615
748
  )
616
749
 
617
750
 
618
- def save_architecture(model: nn.Module, directory: Union[str, Path], verbose: bool=True):
619
- """
620
- Saves a model's architecture to a 'architecture.json' file.
621
-
622
- This function relies on the model having a `get_config()` method that
623
- returns a dictionary of the arguments needed to initialize it.
624
-
625
- Args:
626
- model (nn.Module): The PyTorch model instance to save.
627
- directory (str | Path): The directory to save the JSON file.
628
-
629
- Raises:
630
- AttributeError: If the model does not have a `get_config()` method.
631
- """
632
- if not hasattr(model, 'get_config'):
633
- _LOGGER.error(f"Model '{model.__class__.__name__}' does not have a 'get_config()' method.")
634
- raise AttributeError()
635
-
636
- # Ensure the target directory exists
637
- path_dir = make_fullpath(directory, make=True, enforce="directory")
638
- full_path = path_dir / "architecture.json"
639
-
640
- config = {
641
- 'model_class': model.__class__.__name__,
642
- 'config': model.get_config() # type: ignore
643
- }
644
-
645
- with open(full_path, 'w') as f:
646
- json.dump(config, f, indent=4)
647
-
648
- if verbose:
649
- _LOGGER.info(f"Architecture for '{model.__class__.__name__}' saved to '{path_dir.name}'")
650
-
651
-
652
- def load_architecture(filepath: Union[str, Path], expected_model_class: type, verbose: bool=True) -> nn.Module:
653
- """
654
- Loads a model architecture from a JSON file.
655
-
656
- This function instantiates a model by providing an explicit class to use
657
- and checking that it matches the class name specified in the file.
658
-
659
- Args:
660
- filepath (Union[str, Path]): The path of the JSON architecture file.
661
- expected_model_class (type): The model class expected to load (e.g., MultilayerPerceptron).
662
-
663
- Returns:
664
- nn.Module: An instance of the model with a freshly initialized state.
665
-
666
- Raises:
667
- FileNotFoundError: If the filepath does not exist.
668
- ValueError: If the class name in the file does not match the `expected_model_class`.
669
- """
670
- path_obj = make_fullpath(filepath, enforce="file")
671
-
672
- with open(path_obj, 'r') as f:
673
- saved_data = json.load(f)
674
-
675
- saved_class_name = saved_data['model_class']
676
- config = saved_data['config']
677
-
678
- if saved_class_name != expected_model_class.__name__:
679
- _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{expected_model_class.__name__}' was expected.")
680
- raise ValueError()
681
-
682
- # Create an instance of the model using the provided class and config
683
- model = expected_model_class(**config)
684
- if verbose:
685
- _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
686
- return model
687
-
751
+ # ---- PyTorch models ---
688
752
 
689
753
  def info():
690
754
  _script_info(__all__)