dragon-ml-toolbox 6.4.1__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic (see the registry page for details).

ml_tools/ML_models.py CHANGED
@@ -1,45 +1,35 @@
  import torch
  from torch import nn
- from ._script_info import _script_info
- from typing import List, Union
+ from typing import List, Union, Tuple, Dict, Any
  from pathlib import Path
  import json
  from ._logger import _LOGGER
  from .path_manager import make_fullpath
-
+ from ._script_info import _script_info

  __all__ = [
      "MultilayerPerceptron",
+     "AttentionMLP",
+     "MultiHeadAttentionMLP",
+     "TabularTransformer",
      "SequencePredictorLSTM",
      "save_architecture",
      "load_architecture"
  ]


- class MultilayerPerceptron(nn.Module):
+ class _BaseMLP(nn.Module):
      """
-     Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
-
-     This model generates raw output values (logits) suitable for use with loss
-     functions like `nn.CrossEntropyLoss` (for classification) or `nn.MSELoss`
-     (for regression).
-
-     Args:
-         in_features (int): The number of input features (e.g., columns in your data).
-         out_targets (int): The number of output targets. For regression, this is
-             typically 1. For classification, it's the number of classes.
-         hidden_layers (list[int]): A list where each integer represents the
-             number of neurons in a hidden layer. Defaults to [40, 80, 40].
-         drop_out (float): The dropout probability for neurons in each hidden
-             layer. Must be between 0.0 and 1.0. Defaults to 0.2.
-
-     ### Rules of thumb:
-     - Choose a number of hidden neurons between the size of the input layer and the size of the output layer.
-     - The number of hidden neurons should be 2/3 the size of the input layer, plus the size of the output layer.
-     - The number of hidden neurons should be less than twice the size of the input layer.
+     A base class for Multilayer Perceptrons.
+
+     Handles validation, configuration, and the creation of the core MLP layers,
+     allowing subclasses to define their own pre-processing and forward pass.
      """
-     def __init__(self, in_features: int, out_targets: int,
-                  hidden_layers: List[int] = [40, 80, 40], drop_out: float = 0.2) -> None:
+     def __init__(self,
+                  in_features: int,
+                  out_targets: int,
+                  hidden_layers: List[int],
+                  drop_out: float) -> None:
          super().__init__()

          # --- Validation ---
@@ -58,50 +48,485 @@ class MultilayerPerceptron(nn.Module):
          self.hidden_layers = hidden_layers
          self.drop_out = drop_out

-         # --- Build network layers ---
-         layers = []
+         # --- Build the core MLP network ---
+         mlp_layers = []
          current_features = in_features
          for neurons in hidden_layers:
-             layers.extend([
+             mlp_layers.extend([
                  nn.Linear(current_features, neurons),
                  nn.BatchNorm1d(neurons),
                  nn.ReLU(),
                  nn.Dropout(p=drop_out)
              ])
              current_features = neurons
+
+         self.mlp = nn.Sequential(*mlp_layers)
+         # Set a customizable Prediction Head for flexibility, especially in transfer learning and fine-tuning
+         self.output_layer = nn.Linear(current_features, out_targets)

-         # Add the final output layer
-         layers.append(nn.Linear(current_features, out_targets))
-
-         self._layers = nn.Sequential(*layers)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """Defines the forward pass of the model."""
-         return self._layers(x)
-
-     def get_config(self) -> dict:
-         """Returns the configuration of the model."""
+     def get_config(self) -> Dict[str, Any]:
+         """Returns the base configuration of the model."""
          return {
              'in_features': self.in_features,
              'out_targets': self.out_targets,
              'hidden_layers': self.hidden_layers,
              'drop_out': self.drop_out
          }
+
+     def _repr_helper(self, name: str, mlp_layers: list[str]):
+         last_layer = self.output_layer
+         if isinstance(last_layer, nn.Linear):
+             mlp_layers.append(str(last_layer.out_features))
+         else:
+             mlp_layers.append("Custom Prediction Head")
+
+         # Creates a string like: 10 -> 40 -> 80 -> 40 -> 2
+         arch_str = ' -> '.join(mlp_layers)
+
+         return f"{name}(arch: {arch_str})"
+
+
+ class MultilayerPerceptron(_BaseMLP):
+     """
+     Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
+
+     This model generates raw output values (logits) suitable for use with loss
+     functions like `nn.CrossEntropyLoss` (for classification) or `nn.MSELoss`
+     (for regression).
+     """
+     def __init__(self, in_features: int, out_targets: int,
+                  hidden_layers: List[int] = [256, 128], drop_out: float = 0.2) -> None:
+         """
+         Args:
+             in_features (int): The number of input features (e.g., columns in your data).
+             out_targets (int): The number of output targets. For regression, this is
+                 typically 1. For classification, it's the number of classes.
+             hidden_layers (list[int]): A list where each integer represents the
+                 number of neurons in a hidden layer.
+             drop_out (float): The dropout probability for neurons in each hidden
+                 layer. Must be between 0.0 and 1.0.
+
+         ### Rules of thumb:
+         - Choose a number of hidden neurons between the size of the input layer and the size of the output layer.
+         - The number of hidden neurons should be 2/3 the size of the input layer, plus the size of the output layer.
+         - The number of hidden neurons should be less than twice the size of the input layer.
+         """
+         super().__init__(in_features, out_targets, hidden_layers, drop_out)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Defines the forward pass of the model."""
+         x = self.mlp(x)
+         logits = self.output_layer(x)
+         return logits

      def __repr__(self) -> str:
          """Returns the developer-friendly string representation of the model."""
          # Extracts the number of neurons from each nn.Linear layer
-         layer_sizes = [layer.in_features for layer in self._layers if isinstance(layer, nn.Linear)]
+         layer_sizes = [str(layer.in_features) for layer in self.mlp if isinstance(layer, nn.Linear)]

-         # Get the last layer and check its type before accessing the attribute
-         last_layer = self._layers[-1]
-         if isinstance(last_layer, nn.Linear):
-             layer_sizes.append(last_layer.out_features)
+         return self._repr_helper(name="MultilayerPerceptron", mlp_layers=layer_sizes)
+
+
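The refactor splits the old monolithic MultilayerPerceptron into a `_BaseMLP` (core layers plus an `output_layer` prediction head) and thin subclasses that define their own forward pass. A minimal usage sketch of the new `MultilayerPerceptron` follows; the import path is assumed from the file location `ml_tools/ML_models.py` and may differ in your install:

```python
import torch
from ml_tools.ML_models import MultilayerPerceptron  # assumed import path

# 3-class classifier over 10 input features, using the new default hidden_layers=[256, 128]
model = MultilayerPerceptron(in_features=10, out_targets=3)

x = torch.randn(32, 10)                  # batch of 32 samples
logits = model(x)                        # raw logits, shape (32, 3)
loss = torch.nn.CrossEntropyLoss()(logits, torch.randint(0, 3, (32,)))

print(model)                             # architecture summary string
print(model.get_config())                # {'in_features': 10, 'out_targets': 3, 'hidden_layers': [256, 128], 'drop_out': 0.2}
```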
+ class AttentionMLP(_BaseMLP):
+     """
+     A Multilayer Perceptron (MLP) that incorporates an Attention layer to dynamically weigh input features.
+
+     In inference mode use `forward_attention()` to get a tuple with `(output, attention_weights)`.
+     """
+     def __init__(self, in_features: int, out_targets: int,
+                  hidden_layers: List[int] = [256, 128], drop_out: float = 0.2) -> None:
+         """
+         Args:
+             in_features (int): The number of input features (e.g., columns in your data).
+             out_targets (int): The number of output targets. For regression, this is
+                 typically 1. For classification, it's the number of classes.
+             hidden_layers (list[int]): A list where each integer represents the
+                 number of neurons in a hidden layer.
+             drop_out (float): The dropout probability for neurons in each hidden
+                 layer. Must be between 0.0 and 1.0.
+         """
+         super().__init__(in_features, out_targets, hidden_layers, drop_out)
+         # Attention
+         self.attention = _AttentionLayer(in_features)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Defines the standard forward pass.
+         """
+         logits, _attention_weights = self.forward_attention(x)
+         return logits
+
+     def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Returns logits and attention weights.
+         """
+         # The attention layer returns the processed x and the weights
+         x, attention_weights = self.attention(x)

-         # Creates a string like: 10 -> 40 -> 80 -> 40 -> 2
-         arch_str = ' -> '.join(map(str, layer_sizes))
+         # Pass the attention-modified tensor through the MLP and prediction head
+         logits = self.output_layer(self.mlp(x))
+
+         return logits, attention_weights
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         # Start with the input features and the attention marker
+         arch = [str(self.in_features), "[Attention]"]
+
+         # Find all other linear layers in the MLP
+         for layer in self.mlp[1:]:
+             if isinstance(layer, nn.Linear):
+                 arch.append(str(layer.in_features))
+
+         return self._repr_helper(name="AttentionMLP", mlp_layers=arch)
+
+
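As the docstring notes, `forward_attention()` exposes the per-feature attention weights for interpretability. A short sketch, again assuming the same import path:

```python
import torch
from ml_tools.ML_models import AttentionMLP  # assumed import path

model = AttentionMLP(in_features=10, out_targets=2)
model.eval()  # inference mode

with torch.no_grad():
    x = torch.randn(8, 10)
    logits, attn_weights = model.forward_attention(x)

# One weight per input feature, normalized by softmax over the feature dimension
print(attn_weights.shape)       # torch.Size([8, 10])
print(attn_weights.sum(dim=1))  # each row sums to ~1.0
```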
+ class MultiHeadAttentionMLP(_BaseMLP):
+     """
+     An MLP that incorporates a standard `nn.MultiheadAttention` layer to process
+     the input features.
+
+     In inference mode use `forward_attention()` to get a tuple with `(output, attention_weights)`.
+     """
+     def __init__(self, in_features: int, out_targets: int,
+                  hidden_layers: List[int] = [256, 128], drop_out: float = 0.2,
+                  num_heads: int = 4, attention_dropout: float = 0.1) -> None:
+         """
+         Args:
+             in_features (int): The number of input features.
+             out_targets (int): The number of output targets.
+             hidden_layers (list[int]): A list of neuron counts for each hidden layer.
+             drop_out (float): The dropout probability for the MLP layers.
+             num_heads (int): The number of attention heads.
+             attention_dropout (float): Dropout probability in the attention layer.
+         """
+         super().__init__(in_features, out_targets, hidden_layers, drop_out)
+         self.num_heads = num_heads
+         self.attention_dropout = attention_dropout
+
+         self.attention = _MultiHeadAttentionLayer(
+             num_features=in_features,
+             num_heads=num_heads,
+             dropout=attention_dropout
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Defines the standard forward pass of the model."""
+         logits, _attention_weights = self.forward_attention(x)
+         return logits
+
+     def forward_attention(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Returns logits and attention weights.
+         """
+         # The attention layer returns the processed x and the weights
+         x, attention_weights = self.attention(x)
+
+         # Pass the attention-modified tensor through the MLP and prediction head
+         x = self.mlp(x)
+         logits = self.output_layer(x)
+
+         return logits, attention_weights
+
+     def get_config(self) -> Dict[str, Any]:
+         """Returns the full configuration of the model."""
+         config = super().get_config()
+         config['num_heads'] = self.num_heads
+         config['attention_dropout'] = self.attention_dropout
+         return config
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         mlp_part = " -> ".join(
+             [str(self.in_features)] +
+             [str(h) for h in self.hidden_layers] +
+             [str(self.out_targets)]
+         )
+         arch_str = f"{self.in_features} -> [MultiHead(h={self.num_heads})] -> {mlp_part}"
+
+         return f"MultiHeadAttentionMLP(arch: {arch_str})"
+
+
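One constraint worth noting: `nn.MultiheadAttention` requires its embedding dimension to be divisible by the number of heads, so `in_features` must be divisible by `num_heads` here. A brief sketch (import path assumed):

```python
from ml_tools.ML_models import MultiHeadAttentionMLP  # assumed import path

# in_features=12 is compatible with num_heads=4 (12 % 4 == 0); in_features=10 would raise
model = MultiHeadAttentionMLP(in_features=12, out_targets=1, num_heads=4)

# get_config() returns the base keys plus the attention-specific ones added by the subclass:
# {'in_features': 12, 'out_targets': 1, 'hidden_layers': [256, 128],
#  'drop_out': 0.2, 'num_heads': 4, 'attention_dropout': 0.1}
print(model.get_config())
```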
+ class TabularTransformer(nn.Module):
+     """
+     A Transformer-based model for tabular data tasks.
+
+     This model uses a Feature Tokenizer to convert all input features into a sequence of embeddings, prepends a [CLS] token, and processes the
+     sequence with a standard Transformer Encoder.
+     """
+     def __init__(self, *,
+                  out_targets: int,
+                  numerical_indices: List[int],
+                  categorical_map: Dict[int, int],
+                  embedding_dim: int = 32,
+                  num_heads: int = 8,
+                  num_layers: int = 6,
+                  dropout: float = 0.1):
+         """
+         Args:
+             out_targets (int): Number of output targets (1 for regression).
+             numerical_indices (List[int]): Column indices for numerical features.
+             categorical_map (Dict[int, int]): Maps categorical column index to its cardinality (number of unique categories).
+             embedding_dim (int): The dimension for all feature embeddings. Must be divisible by num_heads.
+             num_heads (int): The number of heads in the multi-head attention mechanism.
+             num_layers (int): The number of sub-encoder-layers in the transformer encoder.
+             dropout (float): The dropout value.
+
+         Note:
+             - All arguments are keyword-only to promote clarity.
+             - Column indices start at 0.

-         return f"MultilayerPerceptron(arch: {arch_str})"
+         ### Data Preparation
+         The model requires a specific input format. All columns in the input DataFrame must be numerical, but they are treated differently based on the
+         provided index lists.
+
+         **Nominal Categorical Features** (e.g., 'City', 'Color'): Should **NOT** be one-hot encoded.
+         Instead, convert them to integer codes (label encoding). You must then provide a dictionary mapping their column indices to
+         their cardinality (the number of unique categories) via the `categorical_map` parameter.
+
+         **Ordinal & Binary Features** (e.g., 'Low/Medium/High', 'True/False'): Should be treated as **numerical**. Map them to numbers that
+         represent their state (e.g., `{'Low': 0, 'Medium': 1}` or `{False: 0, True: 1}`). Their column indices should be included in the
+         `numerical_indices` list.
+
+         **Standard Numerical Features** (e.g., 'Age', 'Price'): Should be included in the `numerical_indices` list. It is highly recommended to
+         scale them before training.
+         """
+         super().__init__()
+
+         # --- Save configuration ---
+         self.out_targets = out_targets
+         self.numerical_indices = numerical_indices
+         self.categorical_map = categorical_map
+         self.embedding_dim = embedding_dim
+         self.num_heads = num_heads
+         self.num_layers = num_layers
+         self.dropout = dropout
+
+         # --- 1. Feature Tokenizer ---
+         self.tokenizer = _FeatureTokenizer(
+             numerical_indices=numerical_indices,
+             categorical_map=categorical_map,
+             embedding_dim=embedding_dim
+         )
+
+         # --- 2. CLS Token ---
+         # A learnable token that will be prepended to the sequence.
+         self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
+
+         # --- 3. Transformer Encoder ---
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=embedding_dim,
+             nhead=num_heads,
+             dropout=dropout,
+             batch_first=True  # Crucial for (batch, seq, feature) input
+         )
+         self.transformer_encoder = nn.TransformerEncoder(
+             encoder_layer=encoder_layer,
+             num_layers=num_layers
+         )
+
+         # --- 4. Prediction Head ---
+         self.output_layer = nn.Linear(embedding_dim, out_targets)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Defines the forward pass of the model."""
+         # Get the batch size for later use
+         batch_size = x.shape[0]
+
+         # 1. Get feature tokens from the tokenizer
+         # -> tokens shape: (batch_size, num_features, embedding_dim)
+         tokens = self.tokenizer(x)
+
+         # 2. Prepend the [CLS] token to the sequence
+         # -> cls_tokens shape: (batch_size, 1, embedding_dim)
+         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
+         # -> full_sequence shape: (batch_size, num_features + 1, embedding_dim)
+         full_sequence = torch.cat([cls_tokens, tokens], dim=1)
+
+         # 3. Pass the full sequence through the Transformer Encoder
+         # -> transformer_out shape: (batch_size, num_features + 1, embedding_dim)
+         transformer_out = self.transformer_encoder(full_sequence)
+
+         # 4. Isolate the output of the [CLS] token (it's the first one)
+         # -> cls_output shape: (batch_size, embedding_dim)
+         cls_output = transformer_out[:, 0]
+
+         # 5. Pass the [CLS] token's output through the prediction head
+         # -> logits shape: (batch_size, out_targets)
+         logits = self.output_layer(cls_output)
+
+         return logits
+
+     def get_config(self) -> Dict[str, Any]:
+         """Returns the full configuration of the model."""
+         return {
+             'out_targets': self.out_targets,
+             'numerical_indices': self.numerical_indices,
+             'categorical_map': self.categorical_map,
+             'embedding_dim': self.embedding_dim,
+             'num_heads': self.num_heads,
+             'num_layers': self.num_layers,
+             'dropout': self.dropout
+         }
+
+     def __repr__(self) -> str:
+         """Returns the developer-friendly string representation of the model."""
+         num_features = len(self.numerical_indices) + len(self.categorical_map)
+
+         # Build the architecture string part-by-part
+         parts = [
+             f"Tokenizer(features={num_features}, dim={self.embedding_dim})",
+             "[CLS]",
+             f"TransformerEncoder(layers={self.num_layers}, heads={self.num_heads})",
+             f"PredictionHead(outputs={self.out_targets})"
+         ]
+
+         arch_str = " -> ".join(parts)
+
+         return f"TabularTransformer(arch: {arch_str})"
+
+
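The data-preparation contract in the docstring boils down to: label-encode nominal categoricals and declare their cardinalities in `categorical_map`, and put everything else (scaled numericals, ordinal and binary codes) in `numerical_indices`. A sketch with a hypothetical four-column layout (column names and values are illustrative, import path assumed):

```python
import torch
from ml_tools.ML_models import TabularTransformer  # assumed import path

# Hypothetical columns: 0='Age' (numerical), 1='City' (nominal, 5 categories),
# 2='Priority' (ordinal Low/Medium/High -> 0/1/2), 3='Color' (nominal, 3 categories)
numerical_indices = [0, 2]        # scaled numericals plus ordinal/binary codes
categorical_map = {1: 5, 3: 3}    # column index -> cardinality

model = TabularTransformer(
    out_targets=1,                # regression
    numerical_indices=numerical_indices,
    categorical_map=categorical_map,
    embedding_dim=32,             # must be divisible by num_heads
    num_heads=8,
)

# One batch of 16 rows; categorical columns hold integer codes within their cardinality
x = torch.empty(16, 4)
x[:, 0] = torch.randn(16)              # 'Age' (already scaled)
x[:, 1] = torch.randint(0, 5, (16,))   # 'City' codes 0..4
x[:, 2] = torch.randint(0, 3, (16,))   # 'Priority' codes 0..2
x[:, 3] = torch.randint(0, 3, (16,))   # 'Color' codes 0..2

logits = model(x)                      # shape (16, 1)
```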
+ class _FeatureTokenizer(nn.Module):
+     """
+     Transforms raw numerical and categorical features from any column order into a sequence of embeddings.
+     """
+     def __init__(self,
+                  numerical_indices: List[int],
+                  categorical_map: Dict[int, int],
+                  embedding_dim: int):
+         """
+         Args:
+             numerical_indices (List[int]): A list of column indices for the numerical features.
+             categorical_map (Dict[int, int]): A dictionary mapping each categorical column index to its cardinality (number of unique categories).
+             embedding_dim (int): The dimension for all feature embeddings.
+         """
+         super().__init__()
+
+         # Unpack the dictionary into separate lists for indices and cardinalities
+         self.categorical_indices = list(categorical_map.keys())
+         cardinalities = list(categorical_map.values())
+
+         self.numerical_indices = numerical_indices
+         self.embedding_dim = embedding_dim
+
+         # A learnable embedding for each numerical feature
+         self.numerical_embeddings = nn.Parameter(torch.randn(len(numerical_indices), embedding_dim))
+
+         # A standard embedding layer for each categorical feature
+         self.categorical_embeddings = nn.ModuleList(
+             [nn.Embedding(num_embeddings=c, embedding_dim=embedding_dim) for c in cardinalities]
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Processes features from a single input tensor and concatenates them
+         into a sequence of tokens.
+         """
+         # Select the correct columns for each type using the stored indices
+         x_numerical = x[:, self.numerical_indices].float()
+         x_categorical = x[:, self.categorical_indices].long()
+
+         # Process numerical features
+         numerical_tokens = x_numerical.unsqueeze(-1) * self.numerical_embeddings
+
+         # Process categorical features
+         categorical_tokens = []
+         for i, embed_layer in enumerate(self.categorical_embeddings):
+             token = embed_layer(x_categorical[:, i]).unsqueeze(1)
+             categorical_tokens.append(token)
+
+         # Concatenate all tokens into a single sequence
+         if not self.categorical_indices:
+             all_tokens = numerical_tokens
+         elif not self.numerical_indices:
+             all_tokens = torch.cat(categorical_tokens, dim=1)
+         else:
+             all_categorical_tokens = torch.cat(categorical_tokens, dim=1)
+             all_tokens = torch.cat([numerical_tokens, all_categorical_tokens], dim=1)
+
+         return all_tokens
+
+
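For intuition, the tokenizer turns every declared column into one embedding-dimensional token, so the sequence length equals the total number of feature columns. A minimal shape check (the class is private; importing it directly here is only for illustration):

```python
import torch
from ml_tools.ML_models import _FeatureTokenizer  # private helper, imported only to illustrate shapes

tokenizer = _FeatureTokenizer(
    numerical_indices=[0, 1],
    categorical_map={2: 4},   # one categorical column with 4 categories
    embedding_dim=16,
)

x = torch.tensor([[0.5, -1.2, 3.0],
                  [1.1,  0.3, 0.0]])
tokens = tokenizer(x)
print(tokens.shape)  # torch.Size([2, 3, 16]) -> (batch, num_features, embedding_dim)
```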
+ class _AttentionLayer(nn.Module):
+     """
+     Calculates attention weights and applies them to the input features, incorporating a residual connection for improved stability and performance.
+
+     Returns both the final output and the weights for interpretability.
+     """
+     def __init__(self, num_features: int):
+         super().__init__()
+         # The hidden layer size is a hyperparameter
+         hidden_size = max(16, num_features // 4)
+
+         # Learn to produce attention scores
+         self.attention_net = nn.Sequential(
+             nn.Linear(num_features, hidden_size),
+             nn.Tanh(),
+             nn.Linear(hidden_size, num_features)  # Output one score per feature
+         )
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         # x shape: (batch_size, num_features)
+
+         # Get one raw "importance" score per feature
+         attention_scores = self.attention_net(x)
+
+         # Apply the softmax module to get weights that sum to 1
+         attention_weights = self.softmax(attention_scores)
+
+         # Weighted features (attention mechanism's output)
+         weighted_features = x * attention_weights
+
+         # Residual connection
+         residual_connection = x + weighted_features
+
+         return residual_connection, attention_weights
+
+
+ class _MultiHeadAttentionLayer(nn.Module):
+     """
+     A wrapper for the standard `torch.nn.MultiheadAttention` layer.
+
+     This layer treats the entire input feature vector as a single item in a
+     sequence and applies self-attention to it. It is followed by a residual
+     connection and layer normalization, which is a standard block in
+     Transformer-style models.
+     """
+     def __init__(self, num_features: int, num_heads: int, dropout: float):
+         super().__init__()
+         self.attention = nn.MultiheadAttention(
+             embed_dim=num_features,
+             num_heads=num_heads,
+             dropout=dropout,
+             batch_first=True  # Crucial for (batch, seq, feature) input
+         )
+         self.layer_norm = nn.LayerNorm(num_features)
+
+     def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         # x shape: (batch_size, num_features)
+
+         # nn.MultiheadAttention expects a sequence dimension.
+         # We add a sequence dimension of length 1.
+         # x_reshaped shape: (batch_size, 1, num_features)
+         x_reshaped = x.unsqueeze(1)
+
+         # Apply self-attention. query, key, and value are all the same.
+         # attn_output shape: (batch_size, 1, num_features)
+         # attn_weights shape: (batch_size, 1, 1)
+         attn_output, attn_weights = self.attention(
+             query=x_reshaped,
+             key=x_reshaped,
+             value=x_reshaped,
+             need_weights=True,
+             average_attn_weights=True  # Average weights across heads
+         )
+
+         # Add residual connection and apply layer normalization (Post-LN)
+         out = self.layer_norm(x + attn_output.squeeze(1))
+
+         # Squeeze weights for a consistent output shape
+         return out, attn_weights.squeeze()


  class SequencePredictorLSTM(nn.Module):