nextrec 0.4.20-py3-none-any.whl → 0.4.22-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (56)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +9 -4
  3. nextrec/basic/callback.py +39 -87
  4. nextrec/basic/features.py +149 -28
  5. nextrec/basic/heads.py +3 -1
  6. nextrec/basic/layers.py +375 -94
  7. nextrec/basic/loggers.py +236 -39
  8. nextrec/basic/model.py +259 -326
  9. nextrec/basic/session.py +2 -2
  10. nextrec/basic/summary.py +323 -0
  11. nextrec/cli.py +3 -3
  12. nextrec/data/data_processing.py +45 -1
  13. nextrec/data/dataloader.py +2 -2
  14. nextrec/data/preprocessor.py +2 -2
  15. nextrec/loss/__init__.py +0 -4
  16. nextrec/loss/grad_norm.py +3 -3
  17. nextrec/models/multi_task/esmm.py +4 -6
  18. nextrec/models/multi_task/mmoe.py +4 -6
  19. nextrec/models/multi_task/ple.py +6 -8
  20. nextrec/models/multi_task/poso.py +5 -7
  21. nextrec/models/multi_task/share_bottom.py +6 -8
  22. nextrec/models/ranking/afm.py +4 -6
  23. nextrec/models/ranking/autoint.py +4 -6
  24. nextrec/models/ranking/dcn.py +8 -7
  25. nextrec/models/ranking/dcn_v2.py +4 -6
  26. nextrec/models/ranking/deepfm.py +5 -7
  27. nextrec/models/ranking/dien.py +8 -7
  28. nextrec/models/ranking/din.py +8 -7
  29. nextrec/models/ranking/eulernet.py +5 -7
  30. nextrec/models/ranking/ffm.py +5 -7
  31. nextrec/models/ranking/fibinet.py +4 -6
  32. nextrec/models/ranking/fm.py +4 -6
  33. nextrec/models/ranking/lr.py +4 -6
  34. nextrec/models/ranking/masknet.py +8 -9
  35. nextrec/models/ranking/pnn.py +4 -6
  36. nextrec/models/ranking/widedeep.py +5 -7
  37. nextrec/models/ranking/xdeepfm.py +8 -7
  38. nextrec/models/retrieval/dssm.py +4 -10
  39. nextrec/models/retrieval/dssm_v2.py +0 -6
  40. nextrec/models/retrieval/mind.py +4 -10
  41. nextrec/models/retrieval/sdm.py +4 -10
  42. nextrec/models/retrieval/youtube_dnn.py +4 -10
  43. nextrec/models/sequential/hstu.py +1 -3
  44. nextrec/utils/__init__.py +17 -15
  45. nextrec/utils/config.py +15 -5
  46. nextrec/utils/console.py +2 -2
  47. nextrec/utils/feature.py +2 -2
  48. nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
  49. nextrec/utils/torch_utils.py +57 -112
  50. nextrec/utils/types.py +63 -0
  51. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
  52. nextrec-0.4.22.dist-info/RECORD +81 -0
  53. nextrec-0.4.20.dist-info/RECORD +0 -79
  54. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
  55. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
  56. {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py CHANGED
@@ -1,8 +1,8 @@
  """
- Layer implementations used across NextRec models.
+ Layer implementations used across NextRec.

  Date: create on 27/10/2025
- Checkpoint: edit on 20/12/2025
+ Checkpoint: edit on 27/12/2025

  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -10,7 +10,9 @@ from __future__ import annotations

  from collections import OrderedDict
  from itertools import combinations
+ from typing import Literal

+ import math
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
@@ -23,7 +25,9 @@ from nextrec.utils.torch_utils import get_initializer
  class PredictionLayer(nn.Module):
  def __init__(
  self,
- task_type: str | list[str] = "binary",
+ task_type: (
+ Literal["binary", "regression"] | list[Literal["binary", "regression"]]
+ ) = "binary",
  task_dims: int | list[int] | None = None,
  use_bias: bool = True,
  return_logits: bool = False,
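The narrowed task_type annotation makes the supported heads explicit at the call site. A minimal usage sketch under the signature shown above (only the listed arguments come from this hunk; everything else about the layer's behavior is assumed):

    from nextrec.basic.layers import PredictionLayer

    # One task type per head; task_dims may be None, a shared int, or one dim per task.
    head = PredictionLayer(
        task_type=["binary", "regression"],
        task_dims=[1, 1],
        return_logits=False,
    )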
@@ -51,7 +55,8 @@ class PredictionLayer(nn.Module):
  dims = list(task_dims)
  if len(dims) not in (1, len(self.task_types)):
  raise ValueError(
- "[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
+ "[PredictionLayer Error]: task_dims must be None, a single int (shared), "
+ "or a sequence of the same length as task_type."
  )
  if len(dims) == 1 and len(self.task_types) > 1:
  dims = dims * len(self.task_types)
@@ -61,7 +66,7 @@ class PredictionLayer(nn.Module):

  # slice offsets per task
  start = 0
- self.task_slices: list[tuple[int, int]] = []
+ self.task_slices = []
  for dim in self.task_dims:
  if dim < 1:
  raise ValueError("Each task dimension must be >= 1.")
@@ -106,53 +111,96 @@ class EmbeddingLayer(nn.Module):
  super().__init__()
  self.features = list(features)
  self.embed_dict = nn.ModuleDict()
- self.dense_transforms = nn.ModuleDict()
- self.dense_input_dims: dict[str, int] = {}
+ self.dense_transforms = nn.ModuleDict() # dense feature projection layers
+ self.dense_input_dims = {}
+ self.sequence_poolings = nn.ModuleDict()

  for feature in self.features:
  if isinstance(feature, (SparseFeature, SequenceFeature)):
- if feature.embedding_name in self.embed_dict:
- continue
- if getattr(feature, "pretrained_weight", None) is not None:
- weight = feature.pretrained_weight # type: ignore[assignment]
- if weight.shape != (feature.vocab_size, feature.embedding_dim): # type: ignore[assignment]
- raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).") # type: ignore[assignment]
- embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx) # type: ignore[assignment]
- embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained # type: ignore[assignment]
- else:
- embedding = nn.Embedding(
- num_embeddings=feature.vocab_size,
- embedding_dim=feature.embedding_dim,
- padding_idx=feature.padding_idx,
- )
- embedding.weight.requires_grad = feature.trainable
- initialization = get_initializer(
- init_type=feature.init_type,
- activation="linear",
- param=feature.init_params,
- )
- initialization(embedding.weight)
- self.embed_dict[feature.embedding_name] = embedding
+ if feature.embedding_name not in self.embed_dict:
+ if feature.pretrained_weight is not None:
+ weight = feature.pretrained_weight
+ if weight.shape != (
+ feature.vocab_size,
+ feature.embedding_dim,
+ ):
+ raise ValueError(
+ f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim})."
+ )
+ embedding = nn.Embedding.from_pretrained(
+ embeddings=weight,
+ freeze=feature.freeze_pretrained,
+ padding_idx=feature.padding_idx,
+ )
+ embedding.weight.requires_grad = (
+ feature.trainable and not feature.freeze_pretrained
+ )
+ else:
+ embedding = nn.Embedding(
+ num_embeddings=feature.vocab_size,
+ embedding_dim=feature.embedding_dim,
+ padding_idx=feature.padding_idx,
+ )
+ embedding.weight.requires_grad = feature.trainable
+ initialization = get_initializer(
+ init_type=feature.init_type, # type: ignore[arg-type]
+ activation="linear",
+ param=feature.init_params,
+ )
+ initialization(embedding.weight)
+ self.embed_dict[feature.embedding_name] = embedding
+
  elif isinstance(feature, DenseFeature):
- if not feature.use_embedding:
- self.dense_input_dims[feature.name] = max(
- int(getattr(feature, "input_dim", 1)), 1
- )
- continue
+ if not feature.use_projection:
+ input_dim = feature.input_dim
+ self.dense_input_dims[feature.name] = max(int(input_dim), 1)
+ continue # skip if no projection is needed
  if feature.name in self.dense_transforms:
- continue
- in_dim = max(int(getattr(feature, "input_dim", 1)), 1)
- out_dim = max(int(getattr(feature, "embedding_dim", None) or in_dim), 1)
- dense_linear = nn.Linear(in_dim, out_dim, bias=True)
+ continue # skip if already created
+
+ input_dim = feature.input_dim
+ out_dim = (
+ feature.embedding_dim
+ if feature.embedding_dim is not None
+ else input_dim
+ )
+
+ dense_linear = nn.Linear(input_dim, out_dim, bias=True)
  nn.init.xavier_uniform_(dense_linear.weight)
  nn.init.zeros_(dense_linear.bias)
  self.dense_transforms[feature.name] = dense_linear
- self.dense_input_dims[feature.name] = in_dim
+ self.dense_input_dims[feature.name] = input_dim
  else:
  raise TypeError(
  f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}"
  )
- self.output_dim = self.compute_output_dim()
+ if isinstance(feature, SequenceFeature):
+ if feature.name in self.sequence_poolings:
+ continue
+ if feature.combiner == "mean":
+ pooling_layer = AveragePooling()
+ elif feature.combiner == "sum":
+ pooling_layer = SumPooling()
+ elif feature.combiner == "concat":
+ pooling_layer = ConcatPooling()
+ elif feature.combiner == "dot_attention":
+ pooling_layer = DotProductAttentionPooling(feature.embedding_dim)
+ elif feature.combiner == "self_attention":
+ if feature.embedding_dim % 4 != 0:
+ raise ValueError(
+ f"[EmbeddingLayer Error]: self_attention requires embedding_dim divisible by 4, got {feature.embedding_dim}."
+ )
+ pooling_layer = SelfAttentionPooling(
+ feature.embedding_dim, num_heads=4, dropout=0.0
+ )
+ else:
+ raise ValueError(
+ f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
+ )
+ self.sequence_poolings[feature.name] = pooling_layer
+ self.output_dim = (
+ self.compute_output_dim()
+ ) # output dimension of the embedding layer

  def forward(
  self,
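The DenseFeature branch now keys off use_projection (formerly use_embedding): when enabled, a raw dense input of size input_dim is mapped to embedding_dim by a bias-enabled nn.Linear with Xavier-initialized weights; otherwise the values pass through unchanged. A standalone sketch of that mapping in plain PyTorch (toy shapes, not the library API):

    import torch
    import torch.nn as nn

    batch, input_dim, embedding_dim = 4, 3, 16
    value = torch.randn(batch, input_dim)          # raw dense feature, [B, input_dim]

    # What use_projection=True builds per feature in the hunk above.
    projection = nn.Linear(input_dim, embedding_dim, bias=True)
    nn.init.xavier_uniform_(projection.weight)
    nn.init.zeros_(projection.bias)

    print(projection(value).shape)  # torch.Size([4, 16]) with projection
    print(value.shape)              # torch.Size([4, 3]) passthrough when use_projection=False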
@@ -160,8 +208,8 @@ class EmbeddingLayer(nn.Module):
  features: list[object],
  squeeze_dim: bool = False,
  ) -> torch.Tensor:
- sparse_embeds: list[torch.Tensor] = []
- dense_embeds: list[torch.Tensor] = []
+ sparse_embeds = []
+ dense_embeds = []

  for feature in features:
  if isinstance(feature, SparseFeature):
@@ -175,17 +223,7 @@ class EmbeddingLayer(nn.Module):

  embed = self.embed_dict[feature.embedding_name]
  seq_emb = embed(seq_input) # [B, seq_len, emb_dim]
-
- if feature.combiner == "mean":
- pooling_layer = AveragePooling()
- elif feature.combiner == "sum":
- pooling_layer = SumPooling()
- elif feature.combiner == "concat":
- pooling_layer = ConcatPooling()
- else:
- raise ValueError(
- f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
- )
+ pooling_layer = self.sequence_poolings[feature.name]
  feature_mask = InputMask()(x, feature, seq_input)
  sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))

@@ -238,17 +276,16 @@ class EmbeddingLayer(nn.Module):
  )
  value = x[feature.name].float()
  if value.dim() == 1:
- value = value.unsqueeze(-1)
+ value = value.unsqueeze(-1) # [B, 1]
  else:
- value = value.view(value.size(0), -1)
- expected_in_dim = self.dense_input_dims.get(
- feature.name, max(int(getattr(feature, "input_dim", 1)), 1)
- )
- if value.shape[1] != expected_in_dim:
+ value = value.view(value.size(0), -1) # [B, input_dim]
+ input_dim = feature.input_dim
+ assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
+ if value.shape[1] != assert_input_dim:
  raise ValueError(
- f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}."
+ f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
  )
- if not feature.use_embedding:
+ if not feature.use_projection:
  return value
  dense_layer = self.dense_transforms[feature.name]
  return dense_layer(value)
@@ -257,25 +294,25 @@ class EmbeddingLayer(nn.Module):
  self,
  features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None,
  ) -> int:
- candidates = list(features) if features is not None else self.features
- unique_feats = OrderedDict((feat.name, feat) for feat in candidates) # type: ignore[assignment]
+ """Compute the output dimension of the embedding layer."""
+ all_features = list(features) if features is not None else self.features
+ unique_feats = OrderedDict((feat.name, feat) for feat in all_features)
  dim = 0
  for feat in unique_feats.values():
  if isinstance(feat, DenseFeature):
- in_dim = max(int(getattr(feat, "input_dim", 1)), 1)
- if getattr(feat, "use_embedding", False):
- emb_dim = getattr(feat, "embedding_dim", None)
- out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+ if feat.use_projection:
+ out_dim = feat.embedding_dim
  else:
- out_dim = in_dim
+ out_dim = feat.input_dim
  dim += out_dim
  elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
  dim += feat.embedding_dim * feat.max_len
  else:
- dim += feat.embedding_dim # type: ignore[assignment]
+ dim += feat.embedding_dim
  return dim

  def get_input_dim(self, features: list[object] | None = None) -> int:
+ """Get the input dimension for the network based on embedding layer's output dimension."""
  return self.compute_output_dim(features) # type: ignore[assignment]

  @property
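compute_output_dim now trusts the feature attributes directly: a projected dense feature contributes embedding_dim, an unprojected one contributes input_dim, a concat-combined sequence contributes embedding_dim * max_len, and everything else contributes embedding_dim. A worked example with assumed feature settings (values illustrative, not taken from this diff):

    # dense "price":    use_projection=True,  embedding_dim=8            ->   8
    # dense "ctr_bias": use_projection=False, input_dim=2                ->   2
    # sequence "hist":  combiner="concat", embedding_dim=16, max_len=10  -> 160
    # sparse "user_id": embedding_dim=16                                 ->  16
    assert 8 + 2 + 16 * 10 + 16 == 186  # total EmbeddingLayer output_dim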
@@ -339,7 +376,8 @@ class ConcatPooling(nn.Module):
  def forward(
  self, x: torch.Tensor, mask: torch.Tensor | None = None
  ) -> torch.Tensor:
- return x.flatten(start_dim=1, end_dim=2)
+ pooled = x.flatten(start_dim=1, end_dim=2)
+ return pooled


  class AveragePooling(nn.Module):
@@ -349,12 +387,15 @@ class AveragePooling(nn.Module):
  def forward(
  self, x: torch.Tensor, mask: torch.Tensor | None = None
  ) -> torch.Tensor:
+ # mask: matrix with 0/1 values for padding positions
  if mask is None:
- return torch.mean(x, dim=1)
+ pooled = torch.mean(x, dim=1)
  else:
+ # 0/1 matrix * x
  sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
  non_padding_length = mask.sum(dim=-1)
- return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
+ pooled = sum_pooling_matrix / (non_padding_length.float() + 1e-16)
+ return pooled


  class SumPooling(nn.Module):
@@ -365,9 +406,184 @@ class SumPooling(nn.Module):
  self, x: torch.Tensor, mask: torch.Tensor | None = None
  ) -> torch.Tensor:
  if mask is None:
- return torch.sum(x, dim=1)
+ pooled = torch.sum(x, dim=1)
  else:
- return torch.bmm(mask, x).squeeze(1)
+ pooled = torch.bmm(mask, x).squeeze(1)
+ return pooled
+
+
+ class DotProductAttentionPooling(nn.Module):
+ """
+ Dot-product attention pooling with a learnable global query vector.
+
+ Input:
+ x: [B, L, D]
+ mask: [B, 1, L] or [B, L] with 1 for valid tokens, 0 for padding
+ Output:
+ pooled: [B, D]
+ """
+
+ def __init__(self, embedding_dim: int, scale: bool = True, dropout: float = 0.0):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.scale = scale
+ self.dropout = nn.Dropout(dropout)
+ self.query = nn.Parameter(torch.empty(embedding_dim))
+ nn.init.xavier_uniform_(self.query.view(1, -1))
+
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ if x.dim() != 3:
+ raise ValueError(
+ f"[DotProductAttentionPooling Error]: x must be [B,L,D], got {x.shape}"
+ )
+ B, L, D = x.shape
+ if D != self.embedding_dim:
+ raise ValueError(
+ f"[DotProductAttentionPooling Error]: embedding_dim mismatch: {D} vs {self.embedding_dim}"
+ )
+
+ q = self.query.view(1, 1, D) # [1,1,D]
+ scores = (x * q).sum(dim=-1) # [B,L]
+ if self.scale:
+ scores = scores / math.sqrt(D)
+
+ if mask is not None:
+ if mask.dim() == 3: # [B,1,L] or [B,L,1]
+ if mask.size(1) == 1:
+ mask_ = mask.squeeze(1) # [B,L]
+ elif mask.size(-1) == 1:
+ mask_ = mask.squeeze(-1) # [B,L]
+ else:
+ raise ValueError(
+ f"[DotProductAttentionPooling Error]: bad mask shape: {mask.shape}"
+ )
+ elif mask.dim() == 2:
+ mask_ = mask
+ else:
+ raise ValueError(
+ f"[DotProductAttentionPooling Error]: bad mask dim: {mask.dim()}"
+ )
+
+ mask_ = mask_.to(dtype=torch.bool)
+ scores = scores.masked_fill(~mask_, float("-inf")) # mask padding positions
+
+ attn = torch.softmax(scores, dim=-1) # [B,L]
+ attn = self.dropout(attn)
+ attn = torch.nan_to_num(attn, nan=0.0) # handle all -inf case
+ pooled = torch.bmm(attn.unsqueeze(1), x).squeeze(1) # [B,D]
+ return pooled
+
+
+ class SelfAttentionPooling(nn.Module):
+ """
+ Self-attention (MHA) to contextualize tokens, then attention pooling to [B,D].
+
+ Input:
+ x: [B, L, D]
+ mask: [B, 1, L] or [B, L] with 1 for valid tokens, 0 for padding
+ Output:
+ pooled: [B, D]
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int = 2,
+ dropout: float = 0.0,
+ use_residual: bool = True,
+ use_layer_norm: bool = True,
+ use_ffn: bool = False,
+ ):
+ super().__init__()
+ if embedding_dim % num_heads != 0:
+ raise ValueError(
+ f"[SelfAttentionPooling Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+ )
+
+ self.embedding_dim = embedding_dim
+ self.use_residual = use_residual
+ self.use_layer_norm = use_layer_norm
+ self.dropout = nn.Dropout(dropout)
+ self.mha = nn.MultiheadAttention(
+ embed_dim=embedding_dim,
+ num_heads=num_heads,
+ dropout=dropout,
+ batch_first=True,
+ )
+ if use_layer_norm:
+ self.layer_norm_1 = nn.LayerNorm(embedding_dim)
+ else:
+ self.layer_norm_1 = None
+
+ self.use_ffn = use_ffn
+ if use_ffn:
+ self.ffn = nn.Sequential(
+ nn.Linear(embedding_dim, 4 * embedding_dim),
+ nn.ReLU(),
+ nn.Dropout(dropout),
+ nn.Linear(4 * embedding_dim, embedding_dim),
+ )
+ if use_layer_norm:
+ self.layer_norm_2 = nn.LayerNorm(embedding_dim)
+ else:
+ self.layer_norm_2 = None
+
+ self.pool = DotProductAttentionPooling(
+ embedding_dim=embedding_dim, scale=True, dropout=dropout
+ )
+
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ if x.dim() != 3:
+ raise ValueError(
+ f"[SelfAttentionPooling Error]: x must be [B,L,D], got {x.shape}"
+ )
+ B, L, D = x.shape
+ if D != self.embedding_dim:
+ raise ValueError(
+ f"[SelfAttentionPooling Error]: embedding_dim mismatch: {D} vs {self.embedding_dim}"
+ )
+
+ key_padding_mask = None
+ if mask is not None:
+ if mask.dim() == 3:
+ if mask.size(1) == 1:
+ mask_ = mask.squeeze(1) # [B,L]
+ elif mask.size(-1) == 1:
+ mask_ = mask.squeeze(-1) # [B,L]
+ else:
+ raise ValueError(
+ f"[SelfAttentionPooling Error]: bad mask shape: {mask.shape}"
+ )
+ elif mask.dim() == 2:
+ mask_ = mask
+ else:
+ raise ValueError(
+ f"[SelfAttentionPooling Error]: bad mask dim: {mask.dim()}"
+ )
+ key_padding_mask = ~mask_.to(dtype=torch.bool) # True = padding
+
+ attn_out, _ = self.mha(
+ x, x, x, key_padding_mask=key_padding_mask, need_weights=False
+ )
+ if self.use_residual:
+ x = x + self.dropout(attn_out)
+ else:
+ x = self.dropout(attn_out)
+ if self.layer_norm_1 is not None:
+ x = self.layer_norm_1(x)
+
+ if self.use_ffn:
+ ffn_out = self.ffn(x)
+ x = x + self.dropout(ffn_out)
+ if self.layer_norm_2 is not None:
+ x = self.layer_norm_2(x)
+
+ pooled = self.pool(x, mask=mask)
+ return pooled


  class MLP(nn.Module):
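The two new pooling modules collapse a padded sequence of embeddings into a single vector per example. A minimal usage sketch based only on the signatures and shapes documented above (toy values):

    import torch
    from nextrec.basic.layers import DotProductAttentionPooling, SelfAttentionPooling

    x = torch.randn(2, 5, 16)               # [B, L, D] sequence embeddings
    mask = torch.tensor([[1, 1, 1, 0, 0],   # [B, L], 1 = valid token, 0 = padding
                         [1, 1, 1, 1, 1]])

    dot_pool = DotProductAttentionPooling(embedding_dim=16)
    self_pool = SelfAttentionPooling(embedding_dim=16, num_heads=4)

    print(dot_pool(x, mask).shape)   # torch.Size([2, 16])
    print(self_pool(x, mask).shape)  # torch.Size([2, 16])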
@@ -377,10 +593,45 @@ class MLP(nn.Module):
  output_layer: bool = True,
  dims: list[int] | None = None,
  dropout: float = 0.0,
- activation: str = "relu",
+ activation: Literal[
+ "dice",
+ "relu",
+ "relu6",
+ "elu",
+ "selu",
+ "leaky_relu",
+ "prelu",
+ "gelu",
+ "sigmoid",
+ "tanh",
+ "softplus",
+ "softsign",
+ "hardswish",
+ "mish",
+ "silu",
+ "swish",
+ "hardsigmoid",
+ "tanhshrink",
+ "softshrink",
+ "none",
+ "linear",
+ "identity",
+ ] = "relu",
  use_norm: bool = True,
- norm_type: str = "layer_norm",
+ norm_type: Literal["batch_norm", "layer_norm"] = "layer_norm",
  ):
+ """
+ Multi-Layer Perceptron (MLP) module.
+
+ Args:
+ input_dim: Dimension of the input features.
+ output_layer: Whether to include the final output layer. If False, the MLP will output the last hidden layer, else it will output a single value.
+ dims: List of hidden layer dimensions. If None, no hidden layers are added.
+ dropout: Dropout rate between layers.
+ activation: Activation function to use between layers.
+ use_norm: Whether to use normalization layers.
+ norm_type: Type of normalization to use ("batch_norm" or "layer_norm").
+ """
  super().__init__()
  if dims is None:
  dims = []
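Since activation and norm_type are now Literal-typed, invalid strings are flagged by type checkers before they reach the activation factory. A hedged usage sketch based on the documented arguments (layer ordering and forward behavior beyond the docstring are assumptions):

    import torch
    from nextrec.basic.layers import MLP

    mlp = MLP(
        input_dim=32,
        dims=[64, 32],
        activation="gelu",       # must be one of the Literal options above
        norm_type="layer_norm",
        dropout=0.1,
        output_layer=False,      # keep the last hidden representation
    )
    out = mlp(torch.randn(8, 32))  # expected [8, 32]: the last hidden layer when output_layer=False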
@@ -457,7 +708,12 @@ class SENETLayer(nn.Module):

  class BiLinearInteractionLayer(nn.Module):
  def __init__(
- self, input_dim: int, num_fields: int, bilinear_type: str = "field_interaction"
+ self,
+ input_dim: int,
+ num_fields: int,
+ bilinear_type: Literal[
+ "field_all", "field_each", "field_interaction"
+ ] = "field_interaction",
  ):
  super(BiLinearInteractionLayer, self).__init__()
  self.bilinear_type = bilinear_type
@@ -531,14 +787,16 @@ class MultiHeadSelfAttention(nn.Module):
  self.use_residual = use_residual
  self.dropout_rate = dropout

- self.W_Q = nn.Linear(
+ self.q_proj = nn.Linear(
  embedding_dim, embedding_dim, bias=False
  ) # Query projection
- self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False) # Key projection
- self.W_V = nn.Linear(
+ self.k_proj = nn.Linear(
+ embedding_dim, embedding_dim, bias=False
+ ) # Key projection
+ self.v_proj = nn.Linear(
  embedding_dim, embedding_dim, bias=False
  ) # Value projection
- self.W_O = nn.Linear(
+ self.out_proj = nn.Linear(
  embedding_dim, embedding_dim, bias=False
  ) # Output projection

@@ -557,15 +815,15 @@ class MultiHeadSelfAttention(nn.Module):
  # x: [Batch, Length, Dim]
  B, L, D = x.shape

- Q = self.W_Q(x)
- K = self.W_K(x)
- V = self.W_V(x)
+ q = self.q_proj(x)
+ k = self.k_proj(x)
+ v = self.v_proj(x)

- Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(
+ q = q.view(B, L, self.num_heads, self.head_dim).transpose(
  1, 2
  ) # [Batch, Heads, Length, head_dim]
- K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
- V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

  key_padding_mask = None
  if attention_mask is not None:
@@ -582,22 +840,22 @@ class MultiHeadSelfAttention(nn.Module):

  if self.use_flash_attention:
  attn = F.scaled_dot_product_attention(
- Q,
- K,
- V,
+ q,
+ k,
+ v,
  attn_mask=attn_mask,
  dropout_p=self.dropout_rate if self.training else 0.0,
  ) # [B,H,L,dh]
  else:
- scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
  if attn_mask is not None:
  scores = scores.masked_fill(attn_mask, float("-inf"))
  attn_weights = torch.softmax(scores, dim=-1)
  attn_weights = self.dropout(attn_weights)
- attn = torch.matmul(attn_weights, V) # [B,H,L,dh]
+ attn = torch.matmul(attn_weights, v) # [B,H,L,dh]

  attn = attn.transpose(1, 2).contiguous().view(B, L, D)
- out = self.W_O(attn)
+ out = self.out_proj(attn)

  if self.use_residual:
  out = out + x
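The rename is cosmetic; both branches still compute the same attention. With no mask and no dropout, the fused F.scaled_dot_product_attention call and the manual matmul/softmax path agree numerically, as this standalone check (plain PyTorch, not the library class) illustrates:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    B, H, L, dh = 2, 4, 5, 8
    q, k, v = (torch.randn(B, H, L, dh) for _ in range(3))

    fused = F.scaled_dot_product_attention(q, k, v)               # flash/fused branch
    scores = torch.matmul(q, k.transpose(-2, -1)) / (dh ** 0.5)   # manual branch
    manual = torch.matmul(torch.softmax(scores, dim=-1), v)

    print(torch.allclose(fused, manual, atol=1e-5))               # True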
@@ -620,7 +878,30 @@ class AttentionPoolingLayer(nn.Module):
  self,
  embedding_dim: int,
  hidden_units: list = [80, 40],
- activation: str = "sigmoid",
+ activation: Literal[
+ "dice",
+ "relu",
+ "relu6",
+ "elu",
+ "selu",
+ "leaky_relu",
+ "prelu",
+ "gelu",
+ "sigmoid",
+ "tanh",
+ "softplus",
+ "softsign",
+ "hardswish",
+ "mish",
+ "silu",
+ "swish",
+ "hardsigmoid",
+ "tanhshrink",
+ "softshrink",
+ "none",
+ "linear",
+ "identity",
+ ] = "sigmoid",
  use_softmax: bool = False,
  ):
  super().__init__()