nextrec-0.4.1-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +220 -106
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1082 -400
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +272 -95
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +53 -37
  29. nextrec/models/multi_task/mmoe.py +64 -45
  30. nextrec/models/multi_task/ple.py +101 -48
  31. nextrec/models/multi_task/poso.py +113 -36
  32. nextrec/models/multi_task/share_bottom.py +48 -35
  33. nextrec/models/ranking/afm.py +72 -37
  34. nextrec/models/ranking/autoint.py +72 -55
  35. nextrec/models/ranking/dcn.py +55 -35
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +32 -22
  38. nextrec/models/ranking/dien.py +155 -99
  39. nextrec/models/ranking/din.py +85 -57
  40. nextrec/models/ranking/fibinet.py +52 -32
  41. nextrec/models/ranking/fm.py +29 -23
  42. nextrec/models/ranking/masknet.py +91 -29
  43. nextrec/models/ranking/pnn.py +31 -28
  44. nextrec/models/ranking/widedeep.py +34 -26
  45. nextrec/models/ranking/xdeepfm.py +60 -38
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +283 -165
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.4.1.dist-info/RECORD +0 -66
  61. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py CHANGED
@@ -5,19 +5,21 @@ Date: create on 27/10/2025
 Checkpoint: edit on 29/11/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
-from __future__ import annotations
 
-from itertools import combinations
-from collections import OrderedDict
+from __future__ import annotations
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from itertools import combinations
+from collections import OrderedDict
+
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.utils.initializer import get_initializer
 from nextrec.basic.activation import activation_layer
 
+
 class PredictionLayer(nn.Module):
     def __init__(
         self,
@@ -30,7 +32,7 @@ class PredictionLayer(nn.Module):
         self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
         if len(self.task_types) == 0:
             raise ValueError("At least one task_type must be specified.")
-
+
         if task_dims is None:
             dims = [1] * len(self.task_types)
         elif isinstance(task_dims, int):
@@ -38,7 +40,9 @@
         else:
             dims = list(task_dims)
         if len(dims) not in (1, len(self.task_types)):
-            raise ValueError("[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type.")
+            raise ValueError(
+                "[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
+            )
         if len(dims) == 1 and len(self.task_types) > 1:
             dims = dims * len(self.task_types)
         self.task_dims = dims
@@ -62,27 +66,33 @@ class PredictionLayer(nn.Module):
         if x.dim() == 1:
             x = x.unsqueeze(0) # (1 * total_dim)
         if x.shape[-1] != self.total_dim:
-            raise ValueError(f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim}).")
+            raise ValueError(
+                f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+            )
         logits = x if self.bias is None else x + self.bias
         outputs = []
         for task_type, (start, end) in zip(self.task_types, self._task_slices):
-            task_logits = logits[..., start:end] # logits for the current task
+            task_logits = logits[..., start:end]  # logits for the current task
             if self.return_logits:
                 outputs.append(task_logits)
                 continue
             task = task_type.lower()
-            if task == 'binary':
-                activation = torch.sigmoid
-            elif task == 'regression':
-                activation = lambda x: x
-            elif task == 'multiclass':
-                activation = lambda x: torch.softmax(x, dim=-1)
+            if task == "binary":
+                outputs.append(torch.sigmoid(task_logits))
+            elif task == "regression":
+                outputs.append(task_logits)
+            elif task == "multiclass":
+                outputs.append(torch.softmax(task_logits, dim=-1))
             else:
-                raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
-            outputs.append(activation(task_logits))
-        result = torch.cat(outputs, dim=-1) # single: (N,1), multi-task/multi-class: (N,total_dim)
+                raise ValueError(
+                    f"[PredictionLayer Error]: Unsupported task_type '{task_type}'."
+                )
+        result = torch.cat(
+            outputs, dim=-1
+        ) # single: (N,1), multi-task/multi-class: (N,total_dim)
         return result
 
+
 class EmbeddingLayer(nn.Module):
     def __init__(self, features: list):
         super().__init__()
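
For orientation, the per-task activation that this hunk introduces (applying sigmoid/softmax directly to each task's logit slice instead of building an activation function first) is equivalent to the following standalone sketch. The task layout and tensor sizes here are hypothetical, not code from the package:

# Minimal sketch of the per-task slicing and activation, assuming three tasks.
import torch

task_types = ["binary", "regression", "multiclass"]
task_dims = [1, 1, 3]
slices, start = [], 0
for dim in task_dims:
    slices.append((start, start + dim))
    start += dim

logits = torch.randn(4, sum(task_dims))  # (batch, total_dim)
outputs = []
for task, (lo, hi) in zip(task_types, slices):
    part = logits[..., lo:hi]
    if task == "binary":
        outputs.append(torch.sigmoid(part))
    elif task == "regression":
        outputs.append(part)
    else:  # "multiclass"
        outputs.append(torch.softmax(part, dim=-1))

result = torch.cat(outputs, dim=-1)  # (4, 5): one column per task dimension
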
@@ -96,20 +106,30 @@ class EmbeddingLayer(nn.Module):
                 if feature.embedding_name in self.embed_dict:
                     continue
                 if getattr(feature, "pretrained_weight", None) is not None:
-                    weight = feature.pretrained_weight # type: ignore[assignment]
-                    if weight.shape != (feature.vocab_size, feature.embedding_dim): # type: ignore[assignment]
-                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).") # type: ignore[assignment]
-                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx) # type: ignore[assignment]
-                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained # type: ignore[assignment]
+                    weight = feature.pretrained_weight  # type: ignore[assignment]
+                    if weight.shape != (feature.vocab_size, feature.embedding_dim):  # type: ignore[assignment]
+                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).")  # type: ignore[assignment]
+                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx)  # type: ignore[assignment]
+                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained  # type: ignore[assignment]
                 else:
-                    embedding = nn.Embedding(num_embeddings=feature.vocab_size, embedding_dim=feature.embedding_dim, padding_idx=feature.padding_idx)
+                    embedding = nn.Embedding(
+                        num_embeddings=feature.vocab_size,
+                        embedding_dim=feature.embedding_dim,
+                        padding_idx=feature.padding_idx,
+                    )
                     embedding.weight.requires_grad = feature.trainable
-                initialization = get_initializer(init_type=feature.init_type, activation="linear", param=feature.init_params)
+                initialization = get_initializer(
+                    init_type=feature.init_type,
+                    activation="linear",
+                    param=feature.init_params,
+                )
                 initialization(embedding.weight)
                 self.embed_dict[feature.embedding_name] = embedding
             elif isinstance(feature, DenseFeature):
                 if not feature.use_embedding:
-                    self.dense_input_dims[feature.name] = max(int(getattr(feature, "input_dim", 1)), 1)
+                    self.dense_input_dims[feature.name] = max(
+                        int(getattr(feature, "input_dim", 1)), 1
+                    )
                     continue
                 if feature.name in self.dense_transforms:
                     continue
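
The pretrained-weight branch above wraps the standard torch call for loading fixed embedding tables. A minimal sketch of that API, with a hypothetical 5 x 4 weight matrix standing in for feature.pretrained_weight:

# Loading a pretrained embedding table (hypothetical weight, not from the package).
import torch
import torch.nn as nn

weight = torch.randn(5, 4)  # (vocab_size, embedding_dim)
embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=True, padding_idx=0)
embedding.weight.requires_grad = False  # mirrors "trainable and not freeze_pretrained"

ids = torch.tensor([[0, 2, 4]])
print(embedding(ids).shape)  # torch.Size([1, 3, 4])
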
@@ -121,7 +141,9 @@ class EmbeddingLayer(nn.Module):
                 self.dense_transforms[feature.name] = dense_linear
                 self.dense_input_dims[feature.name] = in_dim
             else:
-                raise TypeError(f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}")
+                raise TypeError(
+                    f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}"
+                )
         self.output_dim = self.compute_output_dim()
 
     def forward(
@@ -153,7 +175,9 @@ class EmbeddingLayer(nn.Module):
                 elif feature.combiner == "concat":
                     pooling_layer = ConcatPooling()
                 else:
-                    raise ValueError(f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}")
+                    raise ValueError(
+                        f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
+                    )
                 feature_mask = InputMask()(x, feature, seq_input)
                 sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
 
@@ -168,9 +192,11 @@ class EmbeddingLayer(nn.Module):
             if dense_embeds:
                 pieces.append(torch.cat(dense_embeds, dim=1))
             if not pieces:
-                raise ValueError("[EmbeddingLayer Error]: No input features found for EmbeddingLayer.")
+                raise ValueError(
+                    "[EmbeddingLayer Error]: No input features found for EmbeddingLayer."
+                )
             return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
-
+
         # squeeze_dim=False requires embeddings with identical last dimension
         output_embeddings = list(sparse_embeds)
         if dense_embeds:
@@ -178,36 +204,53 @@ class EmbeddingLayer(nn.Module):
                 target_dim = output_embeddings[0].shape[-1]
                 for emb in dense_embeds:
                     if emb.shape[-1] != target_dim:
-                        raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}.")
+                        raise ValueError(
+                            f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}."
+                        )
                 output_embeddings.extend(emb.unsqueeze(1) for emb in dense_embeds)
             else:
                 dims = {emb.shape[-1] for emb in dense_embeds}
                 if len(dims) != 1:
-                    raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}.")
+                    raise ValueError(
+                        f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}."
+                    )
                 output_embeddings = [emb.unsqueeze(1) for emb in dense_embeds]
         if not output_embeddings:
-            raise ValueError("[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions.")
+            raise ValueError(
+                "[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions."
+            )
         return torch.cat(output_embeddings, dim=1)
 
-    def project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
+    def project_dense(
+        self, feature: DenseFeature, x: dict[str, torch.Tensor]
+    ) -> torch.Tensor:
         if feature.name not in x:
-            raise KeyError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input.")
+            raise KeyError(
+                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input."
+            )
         value = x[feature.name].float()
         if value.dim() == 1:
             value = value.unsqueeze(-1)
         else:
             value = value.view(value.size(0), -1)
-        expected_in_dim = self.dense_input_dims.get(feature.name, max(int(getattr(feature, "input_dim", 1)), 1))
+        expected_in_dim = self.dense_input_dims.get(
+            feature.name, max(int(getattr(feature, "input_dim", 1)), 1)
+        )
         if value.shape[1] != expected_in_dim:
-            raise ValueError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}.")
+            raise ValueError(
+                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}."
+            )
         if not feature.use_embedding:
-            return value
+            return value
         dense_layer = self.dense_transforms[feature.name]
         return dense_layer(value)
 
-    def compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
+    def compute_output_dim(
+        self,
+        features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None,
+    ) -> int:
        candidates = list(features) if features is not None else self.features
-        unique_feats = OrderedDict((feat.name, feat) for feat in candidates) # type: ignore[assignment]
+        unique_feats = OrderedDict((feat.name, feat) for feat in candidates)  # type: ignore[assignment]
         dim = 0
         for feat in unique_feats.values():
             if isinstance(feat, DenseFeature):
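
project_dense above flattens a raw dense value to (batch, expected_in_dim) and, when use_embedding is set, passes it through the stored linear transform. A small sketch of that projection with hypothetical dimensions (a 1-dimensional dense input projected to an assumed embedding_dim of 8):

# Hypothetical dense-feature projection; dense_linear stands in for self.dense_transforms[name].
import torch
import torch.nn as nn

value = torch.tensor([0.3, 1.2, -0.7])   # raw dense feature, shape (batch,)
value = value.float().unsqueeze(-1)      # (batch, 1), matching expected_in_dim = 1
dense_linear = nn.Linear(1, 8)
print(dense_linear(value).shape)         # torch.Size([3, 8])
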
@@ -218,28 +261,34 @@ class EmbeddingLayer(nn.Module):
             elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
                 dim += feat.embedding_dim * feat.max_len
             else:
-                dim += feat.embedding_dim # type: ignore[assignment]
+                dim += feat.embedding_dim  # type: ignore[assignment]
         return dim
 
     def get_input_dim(self, features: list[object] | None = None) -> int:
-        return self.compute_output_dim(features) # type: ignore[assignment]
+        return self.compute_output_dim(features)  # type: ignore[assignment]
 
     @property
     def input_dim(self) -> int:
         return self.output_dim
 
+
 class InputMask(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x: dict[str, torch.Tensor], feature: SequenceFeature, seq_tensor: torch.Tensor | None = None):
+    def forward(
+        self,
+        x: dict[str, torch.Tensor],
+        feature: SequenceFeature,
+        seq_tensor: torch.Tensor | None = None,
+    ):
         if seq_tensor is not None:
             values = seq_tensor
         else:
             values = x[feature.name]
         values = values.long()
         padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
-        mask = (values != padding_idx)
+        mask = values != padding_idx
 
         if mask.dim() == 1:
             # [B] -> [B, 1, 1]
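
InputMask builds a float mask by comparing the id sequence against the feature's padding index and reshaping it so the pooling layers can consume it via torch.bmm. A sketch under the assumption padding_idx = 0, with a hypothetical (batch=2, max_len=4) id tensor:

# Padding-mask sketch (hypothetical ids); the (B, 1, L) layout matches what
# SumPooling/AveragePooling below multiply against the (B, L, D) embeddings.
import torch

values = torch.tensor([[3, 7, 0, 0],
                       [5, 0, 0, 0]])
mask = values != 0         # (2, 4) bool
mask = mask.unsqueeze(1)   # (2, 1, 4)
print(mask.float())
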
@@ -253,14 +302,14 @@ class InputMask(nn.Module):
             if mask.size(1) != 1 and mask.size(2) == 1:
                 mask = mask.squeeze(-1).unsqueeze(1)
         else:
-            raise ValueError(f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}")
+            raise ValueError(
+                f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
+            )
         return mask.float()
 
+
 class LR(nn.Module):
-    def __init__(
-            self,
-            input_dim: int,
-            sigmoid: bool = False):
+    def __init__(self, input_dim: int, sigmoid: bool = False):
         super().__init__()
         self.sigmoid = sigmoid
         self.fc = nn.Linear(input_dim, 1, bias=True)
@@ -271,18 +320,24 @@
         else:
             return self.fc(x)
 
+
 class ConcatPooling(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
-        return x.flatten(start_dim=1, end_dim=2)
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        return x.flatten(start_dim=1, end_dim=2)
+
 
 class AveragePooling(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
         if mask is None:
             return torch.mean(x, dim=1)
         else:
@@ -290,24 +345,29 @@
             non_padding_length = mask.sum(dim=-1)
             return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
 
+
 class SumPooling(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
         if mask is None:
             return torch.sum(x, dim=1)
         else:
             return torch.bmm(mask, x).squeeze(1)
 
+
 class MLP(nn.Module):
     def __init__(
-            self,
-            input_dim: int,
-            output_layer: bool = True,
-            dims: list[int] | None = None,
-            dropout: float = 0.0,
-            activation: str = "relu"):
+        self,
+        input_dim: int,
+        output_layer: bool = True,
+        dims: list[int] | None = None,
+        dropout: float = 0.0,
+        activation: str = "relu",
+    ):
         super().__init__()
         if dims is None:
             dims = []
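
The masked branches of SumPooling and AveragePooling above use a batched matrix product with the (B, 1, L) mask to sum only non-padded timesteps; dividing by the mask sum then gives the masked mean. A quick numerical check with hypothetical sizes:

# Masked pooling check (hypothetical tensors): the bmm-based mean equals a
# manual mean over the valid timesteps.
import torch

x = torch.randn(2, 4, 3)                   # (B, L, D) sequence embeddings
mask = torch.tensor([[[1., 1., 0., 0.]],
                     [[1., 1., 1., 0.]]])  # (B, 1, L)
summed = torch.bmm(mask, x).squeeze(1)     # (B, D) sum over valid steps
mean = summed / (mask.sum(dim=-1) + 1e-16)
print(torch.allclose(mean[0], x[0, :2].mean(dim=0)))  # True
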
@@ -320,29 +380,32 @@ class MLP(nn.Module):
             layers.append(activation_layer(activation))
             layers.append(nn.Dropout(p=dropout))
             current_dim = i_dim
-
+
         if output_layer:
             layers.append(nn.Linear(current_dim, 1))
             self.output_dim = 1
         else:
-            self.output_dim = current_dim
+            self.output_dim = current_dim
         self.mlp = nn.Sequential(*layers)
+
     def forward(self, x):
         return self.mlp(x)
-
+
+
 class FM(nn.Module):
     def __init__(self, reduce_sum: bool = True):
         super().__init__()
         self.reduce_sum = reduce_sum
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        square_of_sum = torch.sum(x, dim=1)**2
+        square_of_sum = torch.sum(x, dim=1) ** 2
         sum_of_square = torch.sum(x**2, dim=1)
         ix = square_of_sum - sum_of_square
         if self.reduce_sum:
             ix = torch.sum(ix, dim=1, keepdim=True)
         return 0.5 * ix
 
+
 class CrossLayer(nn.Module):
     def __init__(self, input_dim: int):
         super(CrossLayer, self).__init__()
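
FM.forward above relies on the factorization-machine identity 0.5 * ((sum_i v_i)^2 - sum_i v_i^2) = sum over pairs i < j of v_i * v_j, applied elementwise over field embeddings. A small numerical check with hypothetical sizes:

# FM identity check (hypothetical tensors): square-of-sum minus sum-of-square
# reproduces the explicit pairwise interaction sum.
import torch
from itertools import combinations

x = torch.randn(2, 5, 8)  # (batch, num_fields, embedding_dim)
square_of_sum = torch.sum(x, dim=1) ** 2
sum_of_square = torch.sum(x ** 2, dim=1)
fm_term = 0.5 * (square_of_sum - sum_of_square)

pairwise = sum(x[:, i] * x[:, j] for i, j in combinations(range(5), 2))
print(torch.allclose(fm_term, pairwise, atol=1e-5))  # True
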
@@ -353,60 +416,74 @@ class CrossLayer(nn.Module):
         x = self.w(x_i) * x_0 + self.b
         return x
 
+
 class SENETLayer(nn.Module):
-    def __init__(
-            self,
-            num_fields: int,
-            reduction_ratio: int = 3):
+    def __init__(self, num_fields: int, reduction_ratio: int = 3):
         super(SENETLayer, self).__init__()
-        reduced_size = max(1, int(num_fields/ reduction_ratio))
-        self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
-                                 nn.ReLU(),
-                                 nn.Linear(reduced_size, num_fields, bias=False),
-                                 nn.ReLU())
+        reduced_size = max(1, int(num_fields / reduction_ratio))
+        self.mlp = nn.Sequential(
+            nn.Linear(num_fields, reduced_size, bias=False),
+            nn.ReLU(),
+            nn.Linear(reduced_size, num_fields, bias=False),
+            nn.ReLU(),
+        )
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         z = torch.mean(x, dim=-1, out=None)
         a = self.mlp(z)
-        v = x*a.unsqueeze(-1)
+        v = x * a.unsqueeze(-1)
         return v
 
+
 class BiLinearInteractionLayer(nn.Module):
     def __init__(
-            self,
-            input_dim: int,
-            num_fields: int,
-            bilinear_type: str = "field_interaction"):
+        self, input_dim: int, num_fields: int, bilinear_type: str = "field_interaction"
+    ):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
         if self.bilinear_type == "field_all":
             self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
         elif self.bilinear_type == "field_each":
-            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
+            self.bilinear_layer = nn.ModuleList(
+                [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
+            )
         elif self.bilinear_type == "field_interaction":
-            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i,j in combinations(range(num_fields), 2)])
+            self.bilinear_layer = nn.ModuleList(
+                [
+                    nn.Linear(input_dim, input_dim, bias=False)
+                    for i, j in combinations(range(num_fields), 2)
+                ]
+            )
         else:
             raise NotImplementedError()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         feature_emb = torch.split(x, 1, dim=1)
         if self.bilinear_type == "field_all":
-            bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
+            bilinear_list = [
+                self.bilinear_layer(v_i) * v_j
+                for v_i, v_j in combinations(feature_emb, 2)
+            ]
         elif self.bilinear_type == "field_each":
-            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)] # type: ignore[assignment]
+            bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]  # type: ignore[assignment]
         elif self.bilinear_type == "field_interaction":
-            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))] # type: ignore[assignment]
+            bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]  # type: ignore[assignment]
         return torch.cat(bilinear_list, dim=1)
 
+
 class MultiHeadSelfAttention(nn.Module):
     def __init__(
-            self,
-            embedding_dim: int,
-            num_heads: int = 2,
-            dropout: float = 0.0,
-            use_residual: bool = True):
+        self,
+        embedding_dim: int,
+        num_heads: int = 2,
+        dropout: float = 0.0,
+        use_residual: bool = True,
+    ):
         super().__init__()
         if embedding_dim % num_heads != 0:
-            raise ValueError(f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
+            raise ValueError(
+                f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+            )
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
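
SENETLayer above squeezes each field embedding to a single scalar (mean over the embedding dimension), passes the per-field vector through a bottleneck MLP, and rescales the fields by the resulting weights. A shape sketch with hypothetical sizes (batch=2, num_fields=6, embedding_dim=8, reduction_ratio=3, so reduced_size=2):

# Squeeze-and-excite shape walk-through (hypothetical tensors).
import torch
import torch.nn as nn

x = torch.randn(2, 6, 8)
z = torch.mean(x, dim=-1)      # (2, 6): one scalar per field
mlp = nn.Sequential(
    nn.Linear(6, 2, bias=False), nn.ReLU(), nn.Linear(2, 6, bias=False), nn.ReLU()
)
a = mlp(z)                     # (2, 6): per-field weights
v = x * a.unsqueeze(-1)        # (2, 6, 8): reweighted field embeddings
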
@@ -417,24 +494,34 @@ class MultiHeadSelfAttention(nn.Module):
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.dropout = nn.Dropout(dropout)
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_fields, _ = x.shape
         Q = self.W_Q(x) # [batch_size, num_fields, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
         # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
-        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
-        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
-        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
         # Attention scores
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
         attention_weights = F.softmax(scores, dim=-1)
         attention_weights = self.dropout(attention_weights)
-        attention_output = torch.matmul(attention_weights, V) # [batch_size, num_heads, num_fields, head_dim]
+        attention_output = torch.matmul(
+            attention_weights, V
+        ) # [batch_size, num_heads, num_fields, head_dim]
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
-        attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)
+        attention_output = attention_output.view(
+            batch_size, num_fields, self.embedding_dim
+        )
         # Residual connection
         if self.use_residual:
             output = attention_output + self.W_Res(x)
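
The reformatted head split above follows the usual scaled dot-product recipe: reshape to (batch, heads, fields, head_dim), score with Q·K^T / sqrt(head_dim), and softmax over the last axis. A shape walk-through with hypothetical sizes (batch=2, num_fields=6, embedding_dim=16, num_heads=2):

# Multi-head reshape and scoring sketch (hypothetical tensors).
import torch

batch_size, num_fields, embedding_dim, num_heads = 2, 6, 16, 2
head_dim = embedding_dim // num_heads
Q = torch.randn(batch_size, num_fields, embedding_dim)
K = torch.randn(batch_size, num_fields, embedding_dim)

Q = Q.view(batch_size, num_fields, num_heads, head_dim).transpose(1, 2)  # (2, 2, 6, 8)
K = K.view(batch_size, num_fields, num_heads, head_dim).transpose(1, 2)  # (2, 2, 6, 8)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)        # (2, 2, 6, 6)
weights = torch.softmax(scores, dim=-1)                                  # each row sums to 1
print(weights.shape)                                                     # torch.Size([2, 2, 6, 6])
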
@@ -443,17 +530,20 @@ class MultiHeadSelfAttention(nn.Module):
             output = F.relu(output)
         return output
 
+
 class AttentionPoolingLayer(nn.Module):
     """
     Attention pooling layer for DIN/DIEN
     Computes attention weights between query (candidate item) and keys (user behavior sequence)
     """
+
     def __init__(
-            self,
-            embedding_dim: int,
-            hidden_units: list = [80, 40],
-            activation: str ='sigmoid',
-            use_softmax: bool = True):
+        self,
+        embedding_dim: int,
+        hidden_units: list = [80, 40],
+        activation: str = "sigmoid",
+        use_softmax: bool = True,
+    ):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.use_softmax = use_softmax
@@ -467,8 +557,14 @@
             input_dim = hidden_unit
         layers.append(nn.Linear(input_dim, 1))
         self.attention_net = nn.Sequential(*layers)
-
-    def forward(self, query: torch.Tensor, keys: torch.Tensor, keys_length: torch.Tensor | None = None, mask: torch.Tensor | None = None):
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        keys: torch.Tensor,
+        keys_length: torch.Tensor | None = None,
+        mask: torch.Tensor | None = None,
+    ):
         """
         Args:
             query: [batch_size, embedding_dim] - candidate item embedding
@@ -479,28 +575,46 @@
             output: [batch_size, embedding_dim] - attention pooled representation
         """
         batch_size, sequence_length, embedding_dim = keys.shape
-        assert query.shape == (batch_size, embedding_dim), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
+        assert query.shape == (
+            batch_size,
+            embedding_dim,
+        ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
         if mask is None and keys_length is not None:
             # keys_length: (batch_size,)
             device = keys.device
-            seq_range = torch.arange(sequence_length, device=device).unsqueeze(0) # (1, sequence_length)
+            seq_range = torch.arange(sequence_length, device=device).unsqueeze(
+                0
+            ) # (1, sequence_length)
             mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
         if mask is not None:
             if mask.dim() == 2:
                 # (B, L)
                 mask = mask.unsqueeze(-1)
-            elif mask.dim() == 3 and mask.shape[1] == 1 and mask.shape[2] == sequence_length:
+            elif (
+                mask.dim() == 3
+                and mask.shape[1] == 1
+                and mask.shape[2] == sequence_length
+            ):
                 # (B, 1, L) -> (B, L, 1)
                 mask = mask.transpose(1, 2)
-            elif mask.dim() == 3 and mask.shape[1] == sequence_length and mask.shape[2] == 1:
+            elif (
+                mask.dim() == 3
+                and mask.shape[1] == sequence_length
+                and mask.shape[2] == 1
+            ):
                 pass
             else:
-                raise ValueError(f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}")
+                raise ValueError(
+                    f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
+                )
             mask = mask.to(keys.dtype)
         # Expand query to (B, L, D)
         query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
         # [query, key, query-key, query*key] -> (B, L, 4D)
-        attention_input = torch.cat([query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1,)
+        attention_input = torch.cat(
+            [query_expanded, keys, query_expanded - keys, query_expanded * keys],
+            dim=-1,
+        )
         attention_scores = self.attention_net(attention_input)
         if mask is not None:
             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
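
The final hunk reformats the DIN-style attention input, which concatenates [query, key, query - key, query * key] per timestep and masks padded positions with a large negative value before any softmax. A sketch of that data flow with hypothetical sizes (batch=2, seq_len=4, embedding_dim=8); the scoring MLP itself is replaced by a random stand-in:

# DIN-style attention input and masking sketch (hypothetical tensors).
import torch

query = torch.randn(2, 8)
keys = torch.randn(2, 4, 8)
mask = torch.tensor([[1., 1., 1., 0.],
                     [1., 1., 0., 0.]]).unsqueeze(-1)   # (2, 4, 1)

query_expanded = query.unsqueeze(1).expand(-1, 4, -1)    # (2, 4, 8)
attention_input = torch.cat(
    [query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1
)                                                         # (2, 4, 32)

scores = torch.randn(2, 4, 1)                             # stand-in for attention_net(attention_input)
scores = scores.masked_fill(mask == 0, -1e9)              # padded steps get ~zero weight
weights = torch.softmax(scores, dim=1)                    # (2, 4, 1)
pooled = (weights * keys).sum(dim=1)                      # (2, 8) attention-pooled behaviors
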