nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py CHANGED
@@ -5,19 +5,21 @@ Date: create on 27/10/2025
  Checkpoint: edit on 29/11/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """
- from __future__ import annotations

- from itertools import combinations
- from collections import OrderedDict
+ from __future__ import annotations

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

+ from itertools import combinations
+ from collections import OrderedDict
+

  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.utils.initializer import get_initializer
  from nextrec.basic.activation import activation_layer

+
  class PredictionLayer(nn.Module):
  def __init__(
  self,
@@ -30,7 +32,7 @@ class PredictionLayer(nn.Module):
  self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
  if len(self.task_types) == 0:
  raise ValueError("At least one task_type must be specified.")
-
+
  if task_dims is None:
  dims = [1] * len(self.task_types)
  elif isinstance(task_dims, int):
@@ -38,7 +40,9 @@ class PredictionLayer(nn.Module):
  else:
  dims = list(task_dims)
  if len(dims) not in (1, len(self.task_types)):
- raise ValueError("[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type.")
+ raise ValueError(
+ "[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
+ )
  if len(dims) == 1 and len(self.task_types) > 1:
  dims = dims * len(self.task_types)
  self.task_dims = dims
@@ -62,29 +66,33 @@ class PredictionLayer(nn.Module):
  if x.dim() == 1:
  x = x.unsqueeze(0) # (1 * total_dim)
  if x.shape[-1] != self.total_dim:
- raise ValueError(f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim}).")
+ raise ValueError(
+ f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+ )
  logits = x if self.bias is None else x + self.bias
  outputs = []
  for task_type, (start, end) in zip(self.task_types, self._task_slices):
- task_logits = logits[..., start:end] # logits for the current task
+ task_logits = logits[..., start:end]  # logits for the current task
  if self.return_logits:
  outputs.append(task_logits)
  continue
  task = task_type.lower()
- if task == 'binary':
- activation = torch.sigmoid
- elif task == 'regression':
- activation = lambda x: x
- elif task == 'multiclass':
- activation = lambda x: torch.softmax(x, dim=-1)
+ if task == "binary":
+ outputs.append(torch.sigmoid(task_logits))
+ elif task == "regression":
+ outputs.append(task_logits)
+ elif task == "multiclass":
+ outputs.append(torch.softmax(task_logits, dim=-1))
  else:
- raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
- outputs.append(activation(task_logits))
- result = torch.cat(outputs, dim=-1)
- if result.shape[-1] == 1:
- result = result.squeeze(-1)
+ raise ValueError(
+ f"[PredictionLayer Error]: Unsupported task_type '{task_type}'."
+ )
+ result = torch.cat(
+ outputs, dim=-1
+ ) # single: (N,1), multi-task/multi-class: (N,total_dim)
  return result

+
  class EmbeddingLayer(nn.Module):
  def __init__(self, features: list):
  super().__init__()
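In this hunk the per-task activation is applied inline and the final squeeze of a trailing singleton dimension is removed, so a single binary task now keeps shape (N, 1) where 0.3.6 returned (N,). A minimal torch-only sketch of that shape difference, shown purely for illustration:

    import torch

    logits = torch.randn(4, 1)
    # 0.3.6 squeezed a trailing singleton dimension after concatenation:
    old_out = torch.sigmoid(logits).squeeze(-1)   # shape (4,)
    # 0.4.2 keeps the task dimension: single task -> (N, 1), multi-task/multi-class -> (N, total_dim)
    new_out = torch.sigmoid(logits)               # shape (4, 1)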
@@ -98,20 +106,30 @@ class EmbeddingLayer(nn.Module):
  if feature.embedding_name in self.embed_dict:
  continue
  if getattr(feature, "pretrained_weight", None) is not None:
- weight = feature.pretrained_weight # type: ignore[assignment]
- if weight.shape != (feature.vocab_size, feature.embedding_dim): # type: ignore[assignment]
- raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).") # type: ignore[assignment]
- embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx) # type: ignore[assignment]
- embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained # type: ignore[assignment]
+ weight = feature.pretrained_weight  # type: ignore[assignment]
+ if weight.shape != (feature.vocab_size, feature.embedding_dim):  # type: ignore[assignment]
+ raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).")  # type: ignore[assignment]
+ embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx)  # type: ignore[assignment]
+ embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained  # type: ignore[assignment]
  else:
- embedding = nn.Embedding(num_embeddings=feature.vocab_size, embedding_dim=feature.embedding_dim, padding_idx=feature.padding_idx)
+ embedding = nn.Embedding(
+ num_embeddings=feature.vocab_size,
+ embedding_dim=feature.embedding_dim,
+ padding_idx=feature.padding_idx,
+ )
  embedding.weight.requires_grad = feature.trainable
- initialization = get_initializer(init_type=feature.init_type, activation="linear", param=feature.init_params)
+ initialization = get_initializer(
+ init_type=feature.init_type,
+ activation="linear",
+ param=feature.init_params,
+ )
  initialization(embedding.weight)
  self.embed_dict[feature.embedding_name] = embedding
  elif isinstance(feature, DenseFeature):
  if not feature.use_embedding:
- self.dense_input_dims[feature.name] = max(int(getattr(feature, "input_dim", 1)), 1)
+ self.dense_input_dims[feature.name] = max(
+ int(getattr(feature, "input_dim", 1)), 1
+ )
  continue
  if feature.name in self.dense_transforms:
  continue
@@ -123,7 +141,9 @@ class EmbeddingLayer(nn.Module):
  self.dense_transforms[feature.name] = dense_linear
  self.dense_input_dims[feature.name] = in_dim
  else:
- raise TypeError(f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}")
+ raise TypeError(
+ f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}"
+ )
  self.output_dim = self.compute_output_dim()

  def forward(
@@ -155,7 +175,9 @@ class EmbeddingLayer(nn.Module):
  elif feature.combiner == "concat":
  pooling_layer = ConcatPooling()
  else:
- raise ValueError(f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}")
+ raise ValueError(
+ f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
+ )
  feature_mask = InputMask()(x, feature, seq_input)
  sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))

@@ -170,9 +192,11 @@ class EmbeddingLayer(nn.Module):
  if dense_embeds:
  pieces.append(torch.cat(dense_embeds, dim=1))
  if not pieces:
- raise ValueError("[EmbeddingLayer Error]: No input features found for EmbeddingLayer.")
+ raise ValueError(
+ "[EmbeddingLayer Error]: No input features found for EmbeddingLayer."
+ )
  return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
-
+
  # squeeze_dim=False requires embeddings with identical last dimension
  output_embeddings = list(sparse_embeds)
  if dense_embeds:
@@ -180,36 +204,53 @@ class EmbeddingLayer(nn.Module):
  target_dim = output_embeddings[0].shape[-1]
  for emb in dense_embeds:
  if emb.shape[-1] != target_dim:
- raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}.")
+ raise ValueError(
+ f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}."
+ )
  output_embeddings.extend(emb.unsqueeze(1) for emb in dense_embeds)
  else:
  dims = {emb.shape[-1] for emb in dense_embeds}
  if len(dims) != 1:
- raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}.")
+ raise ValueError(
+ f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}."
+ )
  output_embeddings = [emb.unsqueeze(1) for emb in dense_embeds]
  if not output_embeddings:
- raise ValueError("[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions.")
+ raise ValueError(
+ "[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions."
+ )
  return torch.cat(output_embeddings, dim=1)

- def project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
+ def project_dense(
+ self, feature: DenseFeature, x: dict[str, torch.Tensor]
+ ) -> torch.Tensor:
  if feature.name not in x:
- raise KeyError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input.")
+ raise KeyError(
+ f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input."
+ )
  value = x[feature.name].float()
  if value.dim() == 1:
  value = value.unsqueeze(-1)
  else:
  value = value.view(value.size(0), -1)
- expected_in_dim = self.dense_input_dims.get(feature.name, max(int(getattr(feature, "input_dim", 1)), 1))
+ expected_in_dim = self.dense_input_dims.get(
+ feature.name, max(int(getattr(feature, "input_dim", 1)), 1)
+ )
  if value.shape[1] != expected_in_dim:
- raise ValueError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}.")
+ raise ValueError(
+ f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}."
+ )
  if not feature.use_embedding:
- return value
+ return value
  dense_layer = self.dense_transforms[feature.name]
  return dense_layer(value)

- def compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
+ def compute_output_dim(
+ self,
+ features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None,
+ ) -> int:
  candidates = list(features) if features is not None else self.features
- unique_feats = OrderedDict((feat.name, feat) for feat in candidates) # type: ignore[assignment]
+ unique_feats = OrderedDict((feat.name, feat) for feat in candidates)  # type: ignore[assignment]
  dim = 0
  for feat in unique_feats.values():
  if isinstance(feat, DenseFeature):
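The longer error messages above all guard the same constraint: with squeeze_dim=False the layer stacks one embedding per field along dim 1, so every field (including projected dense features) must share a single last dimension, while the squeeze_dim=True path concatenates flat vectors and has no such requirement. A torch-only sketch of the two layouts, offered only as an illustration of that reading:

    import torch

    # Flat concatenation (the squeeze_dim=True path): last dimensions may differ.
    pooled_sparse = torch.randn(2, 16)
    raw_dense = torch.randn(2, 3)
    flat = torch.cat([pooled_sparse, raw_dense], dim=1)    # (2, 19)

    # Field-wise stacking (the squeeze_dim=False path): every field needs the same width,
    # which is what the new ValueError messages in this hunk enforce.
    fields = [torch.randn(2, 1, 16), torch.randn(2, 1, 16)]
    stacked = torch.cat(fields, dim=1)                     # (2, 2, 16)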
@@ -220,35 +261,55 @@ class EmbeddingLayer(nn.Module):
  elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
  dim += feat.embedding_dim * feat.max_len
  else:
- dim += feat.embedding_dim # type: ignore[assignment]
+ dim += feat.embedding_dim  # type: ignore[assignment]
  return dim

  def get_input_dim(self, features: list[object] | None = None) -> int:
- return self.compute_output_dim(features) # type: ignore[assignment]
+ return self.compute_output_dim(features)  # type: ignore[assignment]

  @property
  def input_dim(self) -> int:
  return self.output_dim

+
  class InputMask(nn.Module):
  def __init__(self):
  super().__init__()

- def forward(self, x: dict[str, torch.Tensor], feature: SequenceFeature, seq_tensor: torch.Tensor | None = None):
- values = seq_tensor if seq_tensor is not None else x[feature.name]
- if feature.padding_idx is not None:
- mask = (values.long() != feature.padding_idx)
+ def forward(
+ self,
+ x: dict[str, torch.Tensor],
+ feature: SequenceFeature,
+ seq_tensor: torch.Tensor | None = None,
+ ):
+ if seq_tensor is not None:
+ values = seq_tensor
  else:
- mask = (values.long() != 0)
+ values = x[feature.name]
+ values = values.long()
+ padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
+ mask = values != padding_idx
+
  if mask.dim() == 1:
- mask = mask.unsqueeze(-1)
- return mask.unsqueeze(1).float()
+ # [B] -> [B, 1, 1]
+ mask = mask.unsqueeze(1).unsqueeze(2)
+ elif mask.dim() == 2:
+ # [B, L] -> [B, 1, L]
+ mask = mask.unsqueeze(1)
+ elif mask.dim() == 3:
+ # [B, 1, L]
+ # [B, L, 1] -> [B, L] -> [B, 1, L]
+ if mask.size(1) != 1 and mask.size(2) == 1:
+ mask = mask.squeeze(-1).unsqueeze(1)
+ else:
+ raise ValueError(
+ f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
+ )
+ return mask.float()
+

  class LR(nn.Module):
- def __init__(
- self,
- input_dim: int,
- sigmoid: bool = False):
+ def __init__(self, input_dim: int, sigmoid: bool = False):
  super().__init__()
  self.sigmoid = sigmoid
  self.fc = nn.Linear(input_dim, 1, bias=True)
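The rewritten InputMask defaults padding_idx to 0 and normalizes every supported input rank to a float mask of shape [B, 1, L] (1-D inputs become [B, 1, 1]), instead of the old unsqueeze-then-float behaviour. A torch-only sketch of the 2-D case, mirroring the new forward as read from this hunk:

    import torch

    seq = torch.tensor([[3, 7, 0, 0],
                        [5, 0, 0, 0]])            # padded id sequence, padding_idx taken as 0
    mask = (seq != 0).unsqueeze(1).float()        # [B, L] -> [B, 1, L]
    print(mask.shape)                             # torch.Size([2, 1, 4])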
@@ -259,18 +320,24 @@ class LR(nn.Module):
  else:
  return self.fc(x)

+
  class ConcatPooling(nn.Module):
  def __init__(self):
  super().__init__()

- def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
- return x.flatten(start_dim=1, end_dim=2)
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ return x.flatten(start_dim=1, end_dim=2)
+

  class AveragePooling(nn.Module):
  def __init__(self):
  super().__init__()

- def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
  if mask is None:
  return torch.mean(x, dim=1)
  else:
@@ -278,54 +345,67 @@ class AveragePooling(nn.Module):
  sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
  non_padding_length = mask.sum(dim=-1)
  return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
+
  class SumPooling(nn.Module):
  def __init__(self):
  super().__init__()

- def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
  if mask is None:
  return torch.sum(x, dim=1)
  else:
  return torch.bmm(mask, x).squeeze(1)

+
  class MLP(nn.Module):
  def __init__(
- self,
- input_dim: int,
- output_layer: bool = True,
- dims: list[int] | None = None,
- dropout: float = 0.0,
- activation: str = "relu"):
+ self,
+ input_dim: int,
+ output_layer: bool = True,
+ dims: list[int] | None = None,
+ dropout: float = 0.0,
+ activation: str = "relu",
+ ):
  super().__init__()
  if dims is None:
  dims = []
- layers = list()
+ layers = []
+ current_dim = input_dim
+
  for i_dim in dims:
- layers.append(nn.Linear(input_dim, i_dim))
+ layers.append(nn.Linear(current_dim, i_dim))
  layers.append(nn.BatchNorm1d(i_dim))
  layers.append(activation_layer(activation))
  layers.append(nn.Dropout(p=dropout))
- input_dim = i_dim
+ current_dim = i_dim
+
  if output_layer:
- layers.append(nn.Linear(input_dim, 1))
+ layers.append(nn.Linear(current_dim, 1))
+ self.output_dim = 1
+ else:
+ self.output_dim = current_dim
  self.mlp = nn.Sequential(*layers)

  def forward(self, x):
  return self.mlp(x)

+
  class FM(nn.Module):
  def __init__(self, reduce_sum: bool = True):
  super().__init__()
  self.reduce_sum = reduce_sum

  def forward(self, x: torch.Tensor) -> torch.Tensor:
- square_of_sum = torch.sum(x, dim=1)**2
+ square_of_sum = torch.sum(x, dim=1) ** 2
  sum_of_square = torch.sum(x**2, dim=1)
  ix = square_of_sum - sum_of_square
  if self.reduce_sum:
  ix = torch.sum(ix, dim=1, keepdim=True)
  return 0.5 * ix

+
  class CrossLayer(nn.Module):
  def __init__(self, input_dim: int):
  super(CrossLayer, self).__init__()
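Besides tracking the running width in current_dim (so the input_dim argument is no longer overwritten inside the loop), MLP now exposes output_dim: 1 when output_layer=True, otherwise the last hidden width. A sketch of reading that attribute, assuming nextrec 0.4.2 is installed and using only the constructor arguments visible in this hunk:

    from nextrec.basic.layers import MLP

    tower = MLP(input_dim=16, output_layer=False, dims=[64, 32])
    print(tower.output_dim)   # 32 -- last hidden width when there is no output layer

    head = MLP(input_dim=16, dims=[64, 32])       # output_layer defaults to True
    print(head.output_dim)    # 1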
@@ -336,60 +416,74 @@ class CrossLayer(nn.Module):
  x = self.w(x_i) * x_0 + self.b
  return x

+
  class SENETLayer(nn.Module):
- def __init__(
- self,
- num_fields: int,
- reduction_ratio: int = 3):
+ def __init__(self, num_fields: int, reduction_ratio: int = 3):
  super(SENETLayer, self).__init__()
- reduced_size = max(1, int(num_fields/ reduction_ratio))
- self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
- nn.ReLU(),
- nn.Linear(reduced_size, num_fields, bias=False),
- nn.ReLU())
+ reduced_size = max(1, int(num_fields / reduction_ratio))
+ self.mlp = nn.Sequential(
+ nn.Linear(num_fields, reduced_size, bias=False),
+ nn.ReLU(),
+ nn.Linear(reduced_size, num_fields, bias=False),
+ nn.ReLU(),
+ )
+
  def forward(self, x: torch.Tensor) -> torch.Tensor:
  z = torch.mean(x, dim=-1, out=None)
  a = self.mlp(z)
- v = x*a.unsqueeze(-1)
+ v = x * a.unsqueeze(-1)
  return v

+
  class BiLinearInteractionLayer(nn.Module):
  def __init__(
- self,
- input_dim: int,
- num_fields: int,
- bilinear_type: str = "field_interaction"):
+ self, input_dim: int, num_fields: int, bilinear_type: str = "field_interaction"
+ ):
  super(BiLinearInteractionLayer, self).__init__()
  self.bilinear_type = bilinear_type
  if self.bilinear_type == "field_all":
  self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
  elif self.bilinear_type == "field_each":
- self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
+ self.bilinear_layer = nn.ModuleList(
+ [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
+ )
  elif self.bilinear_type == "field_interaction":
- self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i,j in combinations(range(num_fields), 2)])
+ self.bilinear_layer = nn.ModuleList(
+ [
+ nn.Linear(input_dim, input_dim, bias=False)
+ for i, j in combinations(range(num_fields), 2)
+ ]
+ )
  else:
  raise NotImplementedError()

  def forward(self, x: torch.Tensor) -> torch.Tensor:
  feature_emb = torch.split(x, 1, dim=1)
  if self.bilinear_type == "field_all":
- bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
+ bilinear_list = [
+ self.bilinear_layer(v_i) * v_j
+ for v_i, v_j in combinations(feature_emb, 2)
+ ]
  elif self.bilinear_type == "field_each":
- bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)] # type: ignore[assignment]
+ bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)] # type: ignore[assignment]
  elif self.bilinear_type == "field_interaction":
- bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))] # type: ignore[assignment]
+ bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))] # type: ignore[assignment]
  return torch.cat(bilinear_list, dim=1)

+
  class MultiHeadSelfAttention(nn.Module):
  def __init__(
- self,
- embedding_dim: int,
- num_heads: int = 2,
- dropout: float = 0.0,
- use_residual: bool = True):
+ self,
+ embedding_dim: int,
+ num_heads: int = 2,
+ dropout: float = 0.0,
+ use_residual: bool = True,
+ ):
  super().__init__()
  if embedding_dim % num_heads != 0:
- raise ValueError(f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
+ raise ValueError(
+ f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+ )
  self.embedding_dim = embedding_dim
  self.num_heads = num_heads
  self.head_dim = embedding_dim // num_heads
@@ -400,24 +494,34 @@ class MultiHeadSelfAttention(nn.Module):
  if self.use_residual:
  self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
  self.dropout = nn.Dropout(dropout)
-
+
  def forward(self, x: torch.Tensor) -> torch.Tensor:
  batch_size, num_fields, _ = x.shape
  Q = self.W_Q(x) # [batch_size, num_fields, embedding_dim]
  K = self.W_K(x)
  V = self.W_V(x)
  # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
- Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
- K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
- V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+ Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+ 1, 2
+ )
+ K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+ 1, 2
+ )
+ V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+ 1, 2
+ )
  # Attention scores
- scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
  attention_weights = F.softmax(scores, dim=-1)
  attention_weights = self.dropout(attention_weights)
- attention_output = torch.matmul(attention_weights, V) # [batch_size, num_heads, num_fields, head_dim]
+ attention_output = torch.matmul(
+ attention_weights, V
+ ) # [batch_size, num_heads, num_fields, head_dim]
  # Concatenate heads
  attention_output = attention_output.transpose(1, 2).contiguous()
- attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)
+ attention_output = attention_output.view(
+ batch_size, num_fields, self.embedding_dim
+ )
  # Residual connection
  if self.use_residual:
  output = attention_output + self.W_Res(x)
@@ -426,17 +530,20 @@ class MultiHeadSelfAttention(nn.Module):
  output = F.relu(output)
  return output

+
  class AttentionPoolingLayer(nn.Module):
  """
  Attention pooling layer for DIN/DIEN
  Computes attention weights between query (candidate item) and keys (user behavior sequence)
  """
+
  def __init__(
- self,
- embedding_dim: int,
- hidden_units: list = [80, 40],
- activation: str ='sigmoid',
- use_softmax: bool = True):
+ self,
+ embedding_dim: int,
+ hidden_units: list = [80, 40],
+ activation: str = "sigmoid",
+ use_softmax: bool = True,
+ ):
  super().__init__()
  self.embedding_dim = embedding_dim
  self.use_softmax = use_softmax
@@ -450,8 +557,14 @@ class AttentionPoolingLayer(nn.Module):
  input_dim = hidden_unit
  layers.append(nn.Linear(input_dim, 1))
  self.attention_net = nn.Sequential(*layers)
-
- def forward(self, query: torch.Tensor, keys: torch.Tensor, keys_length: torch.Tensor | None = None, mask: torch.Tensor | None = None):
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ keys: torch.Tensor,
+ keys_length: torch.Tensor | None = None,
+ mask: torch.Tensor | None = None,
+ ):
  """
  Args:
  query: [batch_size, embedding_dim] - candidate item embedding
@@ -462,28 +575,46 @@ class AttentionPoolingLayer(nn.Module):
  output: [batch_size, embedding_dim] - attention pooled representation
  """
  batch_size, sequence_length, embedding_dim = keys.shape
- assert query.shape == (batch_size, embedding_dim), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
+ assert query.shape == (
+ batch_size,
+ embedding_dim,
+ ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
  if mask is None and keys_length is not None:
  # keys_length: (batch_size,)
  device = keys.device
- seq_range = torch.arange(sequence_length, device=device).unsqueeze(0) # (1, sequence_length)
+ seq_range = torch.arange(sequence_length, device=device).unsqueeze(
+ 0
+ ) # (1, sequence_length)
  mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
  if mask is not None:
  if mask.dim() == 2:
  # (B, L)
  mask = mask.unsqueeze(-1)
- elif mask.dim() == 3 and mask.shape[1] == 1 and mask.shape[2] == sequence_length:
+ elif (
+ mask.dim() == 3
+ and mask.shape[1] == 1
+ and mask.shape[2] == sequence_length
+ ):
  # (B, 1, L) -> (B, L, 1)
  mask = mask.transpose(1, 2)
- elif mask.dim() == 3 and mask.shape[1] == sequence_length and mask.shape[2] == 1:
+ elif (
+ mask.dim() == 3
+ and mask.shape[1] == sequence_length
+ and mask.shape[2] == 1
+ ):
  pass
  else:
- raise ValueError(f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}")
+ raise ValueError(
+ f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
+ )
  mask = mask.to(keys.dtype)
  # Expand query to (B, L, D)
  query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
  # [query, key, query-key, query*key] -> (B, L, 4D)
- attention_input = torch.cat([query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1,)
+ attention_input = torch.cat(
+ [query_expanded, keys, query_expanded - keys, query_expanded * keys],
+ dim=-1,
+ )
  attention_scores = self.attention_net(attention_input)
  if mask is not None:
  attention_scores = attention_scores.masked_fill(mask == 0, -1e9)