nextrec-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. nextrec/__init__.py +41 -0
  2. nextrec/__version__.py +1 -0
  3. nextrec/basic/__init__.py +0 -0
  4. nextrec/basic/activation.py +92 -0
  5. nextrec/basic/callback.py +35 -0
  6. nextrec/basic/dataloader.py +447 -0
  7. nextrec/basic/features.py +87 -0
  8. nextrec/basic/layers.py +985 -0
  9. nextrec/basic/loggers.py +124 -0
  10. nextrec/basic/metrics.py +557 -0
  11. nextrec/basic/model.py +1438 -0
  12. nextrec/data/__init__.py +27 -0
  13. nextrec/data/data_utils.py +132 -0
  14. nextrec/data/preprocessor.py +662 -0
  15. nextrec/loss/__init__.py +35 -0
  16. nextrec/loss/loss_utils.py +136 -0
  17. nextrec/loss/match_losses.py +294 -0
  18. nextrec/models/generative/hstu.py +0 -0
  19. nextrec/models/generative/tiger.py +0 -0
  20. nextrec/models/match/__init__.py +13 -0
  21. nextrec/models/match/dssm.py +200 -0
  22. nextrec/models/match/dssm_v2.py +162 -0
  23. nextrec/models/match/mind.py +210 -0
  24. nextrec/models/match/sdm.py +253 -0
  25. nextrec/models/match/youtube_dnn.py +172 -0
  26. nextrec/models/multi_task/esmm.py +129 -0
  27. nextrec/models/multi_task/mmoe.py +161 -0
  28. nextrec/models/multi_task/ple.py +260 -0
  29. nextrec/models/multi_task/share_bottom.py +126 -0
  30. nextrec/models/ranking/__init__.py +17 -0
  31. nextrec/models/ranking/afm.py +118 -0
  32. nextrec/models/ranking/autoint.py +140 -0
  33. nextrec/models/ranking/dcn.py +120 -0
  34. nextrec/models/ranking/deepfm.py +95 -0
  35. nextrec/models/ranking/dien.py +214 -0
  36. nextrec/models/ranking/din.py +181 -0
  37. nextrec/models/ranking/fibinet.py +130 -0
  38. nextrec/models/ranking/fm.py +87 -0
  39. nextrec/models/ranking/masknet.py +125 -0
  40. nextrec/models/ranking/pnn.py +128 -0
  41. nextrec/models/ranking/widedeep.py +105 -0
  42. nextrec/models/ranking/xdeepfm.py +117 -0
  43. nextrec/utils/__init__.py +18 -0
  44. nextrec/utils/common.py +14 -0
  45. nextrec/utils/embedding.py +19 -0
  46. nextrec/utils/initializer.py +47 -0
  47. nextrec/utils/optimizer.py +75 -0
  48. nextrec-0.1.1.dist-info/METADATA +302 -0
  49. nextrec-0.1.1.dist-info/RECORD +51 -0
  50. nextrec-0.1.1.dist-info/WHEEL +4 -0
  51. nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/basic/layers.py (new file, +985 -0):

"""
Layer implementations used across NextRec models.

Date: created on 27/10/2025, updated on 19/11/2025
Author:
    Yang Zhou, zyaztec@gmail.com
"""

from __future__ import annotations

from itertools import combinations
from typing import Iterable, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from nextrec.basic.activation import activation_layer
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
from nextrec.utils.initializer import get_initializer_fn

Feature = Union[DenseFeature, SparseFeature, SequenceFeature]

__all__ = [
    "PredictionLayer",
    "EmbeddingLayer",
    "InputMask",
    "LR",
    "ConcatPooling",
    "AveragePooling",
    "SumPooling",
    "MLP",
    "FM",
    "FFM",
    "CEN",
    "CIN",
    "CrossLayer",
    "CrossNetwork",
    "CrossNetV2",
    "CrossNetMix",
    "SENETLayer",
    "BiLinearInteractionLayer",
    "MultiInterestSA",
    "CapsuleNetwork",
    "MultiHeadSelfAttention",
    "AttentionPoolingLayer",
    "DynamicGRU",
    "AUGRU",
]


class PredictionLayer(nn.Module):
    """Maps model logits to per-task outputs (sigmoid, identity, or softmax)."""

    _CLASSIFICATION_TASKS = {"classification", "binary", "ctr", "ranking", "match", "matching"}
    _REGRESSION_TASKS = {"regression", "continuous"}
    _MULTICLASS_TASKS = {"multiclass", "softmax"}

    def __init__(
        self,
        task_type: Union[str, Sequence[str]] = "binary",
        task_dims: Union[int, Sequence[int], None] = None,
        use_bias: bool = True,
        return_logits: bool = False,
    ):
        super().__init__()

        if isinstance(task_type, str):
            self.task_types = [task_type]
        else:
            self.task_types = list(task_type)

        if len(self.task_types) == 0:
            raise ValueError("At least one task_type must be specified.")

        if task_dims is None:
            dims = [1] * len(self.task_types)
        elif isinstance(task_dims, int):
            dims = [task_dims]
        else:
            dims = list(task_dims)

        if len(dims) not in (1, len(self.task_types)):
            raise ValueError(
                "task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
            )

        if len(dims) == 1 and len(self.task_types) > 1:
            dims = dims * len(self.task_types)

        self.task_dims = dims
        self.total_dim = sum(self.task_dims)
        self.return_logits = return_logits

        # Keep slice offsets per task
        start = 0
        self._task_slices: list[tuple[int, int]] = []
        for dim in self.task_dims:
            if dim < 1:
                raise ValueError("Each task dimension must be >= 1.")
            self._task_slices.append((start, start + dim))
            start += dim

        if use_bias:
            self.bias = nn.Parameter(torch.zeros(self.total_dim))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() == 1:
            x = x.unsqueeze(-1)

        if x.shape[-1] != self.total_dim:
            raise ValueError(
                f"Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
            )

        logits = x if self.bias is None else x + self.bias
        outputs: list[torch.Tensor] = []

        for task_type, (start, end) in zip(self.task_types, self._task_slices):
            task_logits = logits[..., start:end]
            if self.return_logits:
                outputs.append(task_logits)
                continue

            activation = self._get_activation(task_type)
            outputs.append(activation(task_logits))

        result = torch.cat(outputs, dim=-1)
        if result.shape[-1] == 1:
            result = result.squeeze(-1)
        return result

    def _get_activation(self, task_type: str):
        task = task_type.lower()
        if task in self._CLASSIFICATION_TASKS:
            return torch.sigmoid
        if task in self._REGRESSION_TASKS:
            return lambda x: x
        if task in self._MULTICLASS_TASKS:
            return lambda x: torch.softmax(x, dim=-1)
        raise ValueError(f"Unsupported task_type '{task_type}'.")
class EmbeddingLayer(nn.Module):
    """Builds embedding tables for sparse/sequence features and linear projections for dense features."""

    def __init__(self, features: Sequence[Feature]):
        super().__init__()
        self.features = list(features)
        self.embed_dict = nn.ModuleDict()
        self.dense_transforms = nn.ModuleDict()
        self.dense_input_dims: dict[str, int] = {}

        for feature in self.features:
            if isinstance(feature, (SparseFeature, SequenceFeature)):
                if feature.embedding_name in self.embed_dict:
                    continue

                embedding = nn.Embedding(
                    num_embeddings=feature.vocab_size,
                    embedding_dim=feature.embedding_dim,
                    padding_idx=feature.padding_idx,
                )
                embedding.weight.requires_grad = feature.trainable

                initialization = get_initializer_fn(
                    init_type=feature.init_type,
                    activation="linear",
                    param=feature.init_params,
                )
                initialization(embedding.weight)
                self.embed_dict[feature.embedding_name] = embedding

            elif isinstance(feature, DenseFeature):
                if feature.name in self.dense_transforms:
                    continue
                in_dim = max(int(getattr(feature, "input_dim", 1)), 1)
                out_dim = max(int(getattr(feature, "embedding_dim", None) or in_dim), 1)
                dense_linear = nn.Linear(in_dim, out_dim, bias=True)
                nn.init.xavier_uniform_(dense_linear.weight)
                nn.init.zeros_(dense_linear.bias)
                self.dense_transforms[feature.name] = dense_linear
                self.dense_input_dims[feature.name] = in_dim

            else:
                raise TypeError(f"Unsupported feature type: {type(feature)}")
        self.output_dim = self._compute_output_dim()

    def forward(
        self,
        x: dict[str, torch.Tensor],
        features: Sequence[Feature],
        squeeze_dim: bool = False,
    ) -> torch.Tensor:
        sparse_embeds: list[torch.Tensor] = []
        dense_embeds: list[torch.Tensor] = []

        for feature in features:
            if isinstance(feature, SparseFeature):
                embed = self.embed_dict[feature.embedding_name]
                sparse_embeds.append(embed(x[feature.name].long()).unsqueeze(1))

            elif isinstance(feature, SequenceFeature):
                seq_input = x[feature.name].long()
                if feature.max_len is not None and seq_input.size(1) > feature.max_len:
                    seq_input = seq_input[:, -feature.max_len :]

                embed = self.embed_dict[feature.embedding_name]
                seq_emb = embed(seq_input)  # [B, seq_len, emb_dim]

                if feature.combiner == "mean":
                    pooling_layer = AveragePooling()
                elif feature.combiner == "sum":
                    pooling_layer = SumPooling()
                elif feature.combiner == "concat":
                    pooling_layer = ConcatPooling()
                else:
                    raise ValueError(f"Unknown combiner for {feature.name}: {feature.combiner}")

                feature_mask = InputMask()(x, feature, seq_input)
                sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))

            elif isinstance(feature, DenseFeature):
                dense_embeds.append(self._project_dense(feature, x))

        if squeeze_dim:
            flattened_sparse = [emb.flatten(start_dim=1) for emb in sparse_embeds]
            pieces = []
            if flattened_sparse:
                pieces.append(torch.cat(flattened_sparse, dim=1))
            if dense_embeds:
                pieces.append(torch.cat(dense_embeds, dim=1))

            if not pieces:
                raise ValueError("No input features found for EmbeddingLayer.")

            return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)

        # squeeze_dim=False requires embeddings with identical last dimension
        output_embeddings = list(sparse_embeds)
        if dense_embeds:
            target_dim = None
            if output_embeddings:
                target_dim = output_embeddings[0].shape[-1]
            elif len({emb.shape[-1] for emb in dense_embeds}) == 1:
                target_dim = dense_embeds[0].shape[-1]

            if target_dim is not None:
                aligned_dense = [
                    emb.unsqueeze(1) for emb in dense_embeds if emb.shape[-1] == target_dim
                ]
                output_embeddings.extend(aligned_dense)

        if not output_embeddings:
            raise ValueError(
                "squeeze_dim=False requires at least one sparse/sequence feature or "
                "dense features with identical projected dimensions."
            )

        return torch.cat(output_embeddings, dim=1)

    def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
        if feature.name not in x:
            raise KeyError(f"Dense feature '{feature.name}' is missing from input.")

        value = x[feature.name].float()
        if value.dim() == 1:
            value = value.unsqueeze(-1)
        else:
            value = value.view(value.size(0), -1)

        dense_layer = self.dense_transforms[feature.name]
        expected_in_dim = self.dense_input_dims[feature.name]
        if value.shape[1] != expected_in_dim:
            raise ValueError(
                f"Dense feature '{feature.name}' expects {expected_in_dim} inputs but "
                f"got {value.shape[1]}."
            )

        return dense_layer(value)

    def _compute_output_dim(self):
        # Placeholder: the concatenated output width depends on which features are
        # requested at call time, so it is not precomputed here.
        return None
class InputMask(nn.Module):
    """Utility module to build sequence masks for pooling layers."""

    def __init__(self):
        super().__init__()

    def forward(self, x, fea, seq_tensor=None):
        values = seq_tensor if seq_tensor is not None else x[fea.name]
        if fea.padding_idx is not None:
            mask = (values.long() != fea.padding_idx)
        else:
            mask = (values.long() != 0)
        if mask.dim() == 1:
            mask = mask.unsqueeze(-1)
        return mask.unsqueeze(1).float()  # [B, 1, seq_len]


class LR(nn.Module):
    """Wide component from Wide&Deep (Cheng et al., 2016)."""

    def __init__(self, input_dim, sigmoid=False):
        super().__init__()
        self.sigmoid = sigmoid
        self.fc = nn.Linear(input_dim, 1, bias=True)

    def forward(self, x):
        if self.sigmoid:
            return torch.sigmoid(self.fc(x))
        else:
            return self.fc(x)


class ConcatPooling(nn.Module):
    """Concatenates sequence embeddings along the temporal dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        return x.flatten(start_dim=1, end_dim=2)


class AveragePooling(nn.Module):
    """Mean pooling with optional padding mask."""

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        if mask is None:
            return torch.mean(x, dim=1)
        else:
            sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
            non_padding_length = mask.sum(dim=-1)
            return sum_pooling_matrix / (non_padding_length.float() + 1e-16)


class SumPooling(nn.Module):
    """Sum pooling with optional padding mask."""

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        if mask is None:
            return torch.sum(x, dim=1)
        else:
            return torch.bmm(mask, x).squeeze(1)
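# Editorial usage sketch (not part of the original file): masked mean/sum pooling over
# a padded sequence; the [B, 1, seq_len] mask layout matches what InputMask returns.
def _example_masked_pooling():
    seq_emb = torch.randn(4, 6, 8)                                       # [B, seq_len, emb_dim]
    mask = torch.tensor([[1, 1, 1, 0, 0, 0]] * 4).float().unsqueeze(1)   # [B, 1, seq_len]
    mean_vec = AveragePooling()(seq_emb, mask)                           # [B, emb_dim]
    sum_vec = SumPooling()(seq_emb, mask)                                # [B, emb_dim]
    return mean_vec, sum_vec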
class MLP(nn.Module):
    """Stacked fully connected layers used in the deep component."""

    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
        super().__init__()
        if dims is None:
            dims = []
        layers = list()
        for i_dim in dims:
            layers.append(nn.Linear(input_dim, i_dim))
            layers.append(nn.BatchNorm1d(i_dim))
            layers.append(activation_layer(activation))
            layers.append(nn.Dropout(p=dropout))
            input_dim = i_dim
        if output_layer:
            layers.append(nn.Linear(input_dim, 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
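# Editorial usage sketch (not part of the original file): a two-layer tower mapping a
# 64-d feature vector to a single logit; the sizes below are illustrative assumptions.
def _example_mlp():
    tower = MLP(input_dim=64, dims=[128, 32], dropout=0.1, activation="relu", output_layer=True)
    tower.eval()                       # BatchNorm1d needs eval mode (or batch > 1) for tiny demos
    logit = tower(torch.randn(2, 64))  # [batch, 1]
    return logit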
class FM(nn.Module):
    """Factorization Machine (Rendle, 2010) second-order interaction term."""

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x ** 2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix
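# Editorial usage sketch (not part of the original file): FM second-order term over
# stacked field embeddings, as produced by EmbeddingLayer with squeeze_dim=False.
def _example_fm():
    field_emb = torch.randn(32, 10, 16)    # [batch, num_fields, embed_dim]
    return FM(reduce_sum=True)(field_emb)  # [batch, 1]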
class CIN(nn.Module):
    """Compressed Interaction Network from xDeepFM (Lian et al., 2018)."""

    def __init__(self, input_dim, cin_size, split_half=True):
        super().__init__()
        self.num_layers = len(cin_size)
        self.split_half = split_half
        self.conv_layers = torch.nn.ModuleList()
        prev_dim, fc_input_dim = input_dim, 0
        for i in range(self.num_layers):
            cross_layer_size = cin_size[i]
            self.conv_layers.append(
                torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True)
            )
            if self.split_half and i != self.num_layers - 1:
                cross_layer_size //= 2
            prev_dim = cross_layer_size
            fc_input_dim += prev_dim
        self.fc = torch.nn.Linear(fc_input_dim, 1)

    def forward(self, x):
        xs = list()
        x0, h = x.unsqueeze(2), x
        for i in range(self.num_layers):
            x = x0 * h.unsqueeze(1)
            batch_size, f0_dim, fin_dim, embed_dim = x.shape
            x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
            x = F.relu(self.conv_layers[i](x))
            if self.split_half and i != self.num_layers - 1:
                x, h = torch.split(x, x.shape[1] // 2, dim=1)
            else:
                h = x
            xs.append(x)
        return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
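# Editorial usage sketch (not part of the original file): CIN over 10 fields with two
# interaction layers; `input_dim` is the number of fields, not the embedding size.
def _example_cin():
    cin = CIN(input_dim=10, cin_size=[16, 16], split_half=True)
    field_emb = torch.randn(8, 10, 32)   # [batch, num_fields, embed_dim]
    return cin(field_emb)                # [batch, 1]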
class CrossLayer(nn.Module):
    """Single cross layer used in DCN (Wang et al., 2017)."""

    def __init__(self, input_dim):
        super(CrossLayer, self).__init__()
        self.w = torch.nn.Linear(input_dim, 1, bias=False)
        self.b = torch.nn.Parameter(torch.zeros(input_dim))

    def forward(self, x_0, x_i):
        x = self.w(x_i) * x_0 + self.b
        return x


class CrossNetwork(nn.Module):
    """Stacked cross layers from DCN (Wang et al., 2017)."""

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])

    def forward(self, x):
        """
        :param x: Float tensor of size ``(batch_size, input_dim)``, i.e. the flattened feature vector
        """
        x0 = x
        for i in range(self.num_layers):
            xw = self.w[i](x)
            x = x0 * xw + self.b[i] + x
        return x


class CrossNetV2(nn.Module):
    """Vector-wise cross network proposed in DCN V2 (Wang et al., 2021)."""

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])

    def forward(self, x):
        x0 = x
        for i in range(self.num_layers):
            x = x0 * self.w[i](x) + self.b[i] + x
        return x
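# Editorial usage sketch (not part of the original file): DCN-style explicit feature
# crossing on a flattened 64-d input, showing the v1 and v2 variants side by side.
def _example_cross_networks():
    flat = torch.randn(16, 64)                                     # [batch, input_dim]
    crossed_v1 = CrossNetwork(input_dim=64, num_layers=3)(flat)    # [batch, 64]
    crossed_v2 = CrossNetV2(input_dim=64, num_layers=3)(flat)      # [batch, 64]
    return crossed_v1, crossed_v2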
class CrossNetMix(nn.Module):
    """Mixture of low-rank cross experts from DCN V2 (Wang et al., 2021)."""

    def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
        super(CrossNetMix, self).__init__()
        self.num_layers = num_layers
        self.num_experts = num_experts

        # U: (input_dim, low_rank)
        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
        # V: (input_dim, low_rank)
        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
        # C: (low_rank, low_rank)
        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])

        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
            torch.empty(input_dim, 1))) for i in range(self.num_layers)])

    def forward(self, x):
        x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
        x_l = x_0
        for i in range(self.num_layers):
            output_of_experts = []
            gating_score_experts = []
            for expert_id in range(self.num_experts):
                # (1) G(x_l): compute the gating score from x_l
                gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))

                # (2) E(x_l): project the input x_l to $\mathbb{R}^{r}$
                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)

                # nonlinear activation in the low-rank space
                v_x = torch.tanh(v_x)
                v_x = torch.matmul(self.c_list[i][expert_id], v_x)
                v_x = torch.tanh(v_x)

                # project back to $\mathbb{R}^{d}$
                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, in_features, 1)

                dot_ = uv_x + self.bias[i]
                dot_ = x_0 * dot_  # Hadamard product

                output_of_experts.append(dot_.squeeze(2))

            # (3) mixture of low-rank experts
            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
            moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
            x_l = moe_out + x_l  # (bs, in_features, 1)

        x_l = x_l.squeeze(2)  # (bs, in_features); squeeze(2) keeps the batch dim even when bs == 1
        return x_l
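# Editorial usage sketch (not part of the original file): mixture-of-experts low-rank
# crossing (DCN V2) on the same flattened 64-d input used above; sizes are illustrative.
def _example_cross_net_mix():
    mix = CrossNetMix(input_dim=64, num_layers=2, low_rank=16, num_experts=4)
    return mix(torch.randn(16, 64))   # [batch, 64]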
class SENETLayer(nn.Module):
    """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""

    def __init__(self, num_fields, reduction_ratio=3):
        super(SENETLayer, self).__init__()
        reduced_size = max(1, int(num_fields / reduction_ratio))
        self.mlp = nn.Sequential(
            nn.Linear(num_fields, reduced_size, bias=False),
            nn.ReLU(),
            nn.Linear(reduced_size, num_fields, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        z = torch.mean(x, dim=-1)
        a = self.mlp(z)
        v = x * a.unsqueeze(-1)
        return v


class BiLinearInteractionLayer(nn.Module):
    """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""

    def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
        super(BiLinearInteractionLayer, self).__init__()
        self.bilinear_type = bilinear_type
        if self.bilinear_type == "field_all":
            self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
        elif self.bilinear_type == "field_each":
            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
        elif self.bilinear_type == "field_interaction":
            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i, j in combinations(range(num_fields), 2)])
        else:
            raise NotImplementedError()

    def forward(self, x):
        feature_emb = torch.split(x, 1, dim=1)
        if self.bilinear_type == "field_all":
            bilinear_list = [self.bilinear_layer(v_i) * v_j for v_i, v_j in combinations(feature_emb, 2)]
        elif self.bilinear_type == "field_each":
            bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]
        elif self.bilinear_type == "field_interaction":
            bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]
        return torch.cat(bilinear_list, dim=1)
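# Editorial usage sketch (not part of the original file): the FiBiNET pair — SENET
# re-weights fields, then the bilinear layer crosses them pairwise; shapes are illustrative.
def _example_fibinet_blocks():
    field_emb = torch.randn(8, 10, 16)                 # [batch, num_fields, embed_dim]
    reweighted = SENETLayer(num_fields=10)(field_emb)  # [batch, 10, 16]
    crossed = BiLinearInteractionLayer(input_dim=16, num_fields=10)(reweighted)
    return crossed                                     # [batch, 45, 16] (10 choose 2 pairs)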
class MultiInterestSA(nn.Module):
    """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""

    def __init__(self, embedding_dim, interest_num, hidden_dim=None):
        super(MultiInterestSA, self).__init__()
        self.embedding_dim = embedding_dim
        self.interest_num = interest_num
        # Fall back to 4 * embedding_dim when no hidden size is given.
        self.hidden_dim = hidden_dim if hidden_dim is not None else self.embedding_dim * 4
        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)

    def forward(self, seq_emb, mask=None):
        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
        if mask is not None:
            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
            A = F.softmax(A, dim=1)
        else:
            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
        A = A.permute(0, 2, 1)
        multi_interest_emb = torch.matmul(A, seq_emb)
        return multi_interest_emb
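# Editorial usage sketch (not part of the original file): extracting 4 interest vectors
# from a behaviour sequence with the self-attentive MIND variant; shapes are illustrative.
def _example_multi_interest_sa():
    seq_emb = torch.randn(8, 20, 32)   # [batch, seq_len, embedding_dim]
    interests = MultiInterestSA(embedding_dim=32, interest_num=4)(seq_emb)
    return interests                   # [batch, 4, 32]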
class CapsuleNetwork(nn.Module):
    """Dynamic-routing capsule network used in MIND (Li et al., 2019)."""

    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
        super(CapsuleNetwork, self).__init__()
        self.embedding_dim = embedding_dim  # h
        self.seq_len = seq_len  # s
        self.bilinear_type = bilinear_type
        self.interest_num = interest_num
        self.routing_times = routing_times

        self.relu_layer = relu_layer
        self.stop_grad = True
        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
        if self.bilinear_type == 0:  # MIND
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        elif self.bilinear_type == 1:
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
        else:
            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
            nn.init.xavier_uniform_(self.w)

    def forward(self, item_eb, mask):
        if self.bilinear_type == 0:
            item_eb_hat = self.linear(item_eb)
            item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
        elif self.bilinear_type == 1:
            item_eb_hat = self.linear(item_eb)
        else:
            u = torch.unsqueeze(item_eb, dim=2)
            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)

        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
        item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))

        if self.stop_grad:
            item_eb_hat_iter = item_eb_hat.detach()
        else:
            item_eb_hat_iter = item_eb_hat

        if self.bilinear_type > 0:
            capsule_weight = torch.zeros(item_eb_hat.shape[0],
                                         self.interest_num,
                                         self.seq_len,
                                         device=item_eb.device,
                                         requires_grad=False)
        else:
            capsule_weight = torch.randn(item_eb_hat.shape[0],
                                         self.interest_num,
                                         self.seq_len,
                                         device=item_eb.device,
                                         requires_grad=False)

        for i in range(self.routing_times):  # dynamic routing, repeated routing_times iterations
            atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
            paddings = torch.zeros_like(atten_mask, dtype=torch.float)

            capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)

            # only the final routing iteration backpropagates through item_eb_hat
            if i < self.routing_times - 1:
                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                interest_capsule = scalar_factor * interest_capsule

                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
                capsule_weight = capsule_weight + delta_weight
            else:
                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
                interest_capsule = scalar_factor * interest_capsule

        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))

        if self.relu_layer:
            interest_capsule = self.relu(interest_capsule)

        return interest_capsule
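# Editorial usage sketch (not part of the original file): routing a padded behaviour
# sequence into 4 interest capsules; mask marks real positions with 1, padding with 0.
def _example_capsule_network():
    capsules = CapsuleNetwork(embedding_dim=32, seq_len=20, bilinear_type=2, interest_num=4)
    item_eb = torch.randn(8, 20, 32)           # [batch, seq_len, embedding_dim]
    mask = (torch.rand(8, 20) > 0.3).float()   # [batch, seq_len]
    return capsules(item_eb, mask)             # [batch, 4, 32]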
class FFM(nn.Module):
    """Field-aware Factorization Machine (Juan et al., 2016)."""

    def __init__(self, num_fields, reduce_sum=True):
        super().__init__()
        self.num_fields = num_fields
        self.reduce_sum = reduce_sum

    def forward(self, x):
        # compute (non-redundant) second-order field-aware feature crossings
        crossed_embeddings = []
        for i in range(self.num_fields - 1):
            for j in range(i + 1, self.num_fields):
                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
        crossed_embeddings = torch.stack(crossed_embeddings, dim=1)

        # with reduce_sum the crossing is effectively an inner product, otherwise a Hadamard product
        if self.reduce_sum:
            crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
        return crossed_embeddings
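# Editorial usage sketch (not part of the original file): FFM expects field-aware
# embeddings shaped [batch, num_fields, num_fields, embed_dim]; sizes are illustrative.
def _example_ffm():
    ffm = FFM(num_fields=6, reduce_sum=True)
    field_aware_emb = torch.randn(8, 6, 6, 16)
    return ffm(field_aware_emb)   # [batch, 15, 1] - one score per field pair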
class CEN(nn.Module):
    """Field-attentive interaction network from FAT-DeepFFM (Zhang et al., 2019)."""

    def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
        super().__init__()

        # convolution weight (Eq. 7, FAT-DeepFFM)
        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)

        # two FC layers that compute the field attention
        self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses // reduction_ratio, num_field_crosses], output_layer=False, activation="relu")

    def forward(self, em):
        # descriptor vector (Eq. 7), output shape [batch_size, num_field_crosses]
        d = F.relu((self.u.squeeze(0) * em).sum(-1))

        # field attention (Eq. 9), output shape [batch_size, num_field_crosses]
        s = self.mlp_att(d)

        # rescale the original embedding with the field attention (Eq. 10),
        # output shape [batch_size, num_field_crosses, embed_dim]
        aem = s.unsqueeze(-1) * em
        return aem.flatten(start_dim=1)
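# Editorial usage sketch (not part of the original file): attending over the 15
# field-aware crossings produced by the FFM example above; shapes are illustrative.
def _example_cen():
    cen = CEN(embed_dim=16, num_field_crosses=15, reduction_ratio=3)
    crossed = torch.randn(8, 15, 16)   # [batch, num_field_crosses, embed_dim]
    return cen(crossed)                # [batch, 15 * 16] flattened re-weighted crossings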
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""

    def __init__(self, embedding_dim, num_heads=2, dropout=0.0, use_residual=True):
        super().__init__()
        if embedding_dim % num_heads != 0:
            raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        self.use_residual = use_residual

        self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)

        if self.use_residual:
            self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch_size, num_fields, embedding_dim]
        Returns:
            output: [batch_size, num_fields, embedding_dim]
        """
        batch_size, num_fields, _ = x.shape

        # Linear projections
        Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
        K = self.W_K(x)
        V = self.W_V(x)

        # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        attention_output = torch.matmul(attention_weights, V)  # [batch_size, num_heads, num_fields, head_dim]

        # Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)

        # Residual connection
        if self.use_residual:
            output = attention_output + self.W_Res(x)
        else:
            output = attention_output

        output = F.relu(output)

        return output
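# Editorial usage sketch (not part of the original file): one AutoInt-style interacting
# layer over field embeddings; stacking several such layers is the usual pattern.
def _example_multi_head_self_attention():
    attn = MultiHeadSelfAttention(embedding_dim=16, num_heads=2, dropout=0.1)
    field_emb = torch.randn(8, 10, 16)   # [batch, num_fields, embedding_dim]
    return attn(field_emb)               # [batch, 10, 16]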
class AttentionPoolingLayer(nn.Module):
    """
    Attention pooling layer for DIN/DIEN.
    Computes attention weights between the query (candidate item) and keys (user behavior sequence).
    """

    def __init__(self, embedding_dim, hidden_units=[80, 40], activation='sigmoid', use_softmax=True):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_softmax = use_softmax

        # Build the attention network
        # Input: [query, key, query-key, query*key] -> 4 * embedding_dim
        input_dim = 4 * embedding_dim
        layers = []

        for hidden_unit in hidden_units:
            layers.append(nn.Linear(input_dim, hidden_unit))
            layers.append(activation_layer(activation))
            input_dim = hidden_unit

        layers.append(nn.Linear(input_dim, 1))
        self.attention_net = nn.Sequential(*layers)

    def forward(self, query, keys, keys_length=None, mask=None):
        """
        Args:
            query: [batch_size, embedding_dim] - candidate item embedding
            keys: [batch_size, seq_len, embedding_dim] - user behavior sequence
            keys_length: [batch_size] - actual length of each sequence (optional)
            mask: [batch_size, seq_len, 1] - mask for padding (optional)
        Returns:
            output: [batch_size, embedding_dim] - attention-pooled representation
        """
        batch_size, seq_len, emb_dim = keys.shape

        # Expand query to match the sequence length: [batch_size, seq_len, embedding_dim]
        query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)

        # Compute attention features: [query, key, query-key, query*key]
        attention_input = torch.cat([
            query_expanded,
            keys,
            query_expanded - keys,
            query_expanded * keys
        ], dim=-1)  # [batch_size, seq_len, 4*embedding_dim]

        # Compute attention scores
        attention_scores = self.attention_net(attention_input)  # [batch_size, seq_len, 1]

        # Apply mask if provided
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

        # Apply softmax to get attention weights
        if self.use_softmax:
            attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len, 1]
        else:
            attention_weights = attention_scores

        # Weighted sum of keys
        output = torch.sum(attention_weights * keys, dim=1)  # [batch_size, embedding_dim]

        return output
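# Editorial usage sketch (not part of the original file): DIN-style target attention,
# where the candidate item attends over the user's behaviour sequence; shapes are illustrative.
def _example_attention_pooling():
    pool = AttentionPoolingLayer(embedding_dim=32)
    query = torch.randn(8, 32)            # candidate item embedding
    keys = torch.randn(8, 20, 32)         # behaviour sequence embeddings
    mask = torch.ones(8, 20, 1)           # 1 = real position, 0 = padding
    return pool(query, keys, mask=mask)   # [batch, 32]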
class DynamicGRU(nn.Module):
    """
    GRU unrolled step by step, used for interest extraction in DIEN (Zhou et al., 2019).
    `att_scores` is accepted for interface parity with AUGRU but is not used here.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # GRU parameters
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / (self.hidden_size) ** 0.5
        for weight in self.parameters():
            weight.data.uniform_(-std, std)

    def forward(self, x, att_scores=None):
        """
        Args:
            x: [batch_size, seq_len, input_size]
            att_scores: unused; kept so DynamicGRU and AUGRU share a call signature
        Returns:
            output: [batch_size, seq_len, hidden_size]
            hidden: [batch_size, hidden_size] - final hidden state
        """
        batch_size, seq_len, _ = x.shape

        # Initialize the hidden state
        h = torch.zeros(batch_size, self.hidden_size, device=x.device)

        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # [batch_size, input_size]

            # Standard GRU cell computation
            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
            gh = F.linear(h, self.weight_hh, self.bias_hh)
            i_r, i_i, i_n = gi.chunk(3, 1)
            h_r, h_i, h_n = gh.chunk(3, 1)

            resetgate = torch.sigmoid(i_r + h_r)
            inputgate = torch.sigmoid(i_i + h_i)
            newgate = torch.tanh(i_n + resetgate * h_n)
            h = newgate + inputgate * (h - newgate)

            outputs.append(h.unsqueeze(1))

        output = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]

        return output, h
class AUGRU(nn.Module):
    """
    GRU with an attentional update gate (AUGRU) from DIEN (Zhou et al., 2019).
    Attention scores modulate how strongly each step updates the hidden state.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / (self.hidden_size) ** 0.5
        for weight in self.parameters():
            weight.data.uniform_(-std, std)

    def forward(self, x, att_scores):
        """
        Args:
            x: [batch_size, seq_len, input_size]
            att_scores: [batch_size, seq_len, 1] - attention scores
        Returns:
            output: [batch_size, seq_len, hidden_size]
            hidden: [batch_size, hidden_size] - final hidden state
        """
        batch_size, seq_len, _ = x.shape

        h = torch.zeros(batch_size, self.hidden_size, device=x.device)

        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # [batch_size, input_size]
            att_t = att_scores[:, t, :]  # [batch_size, 1]

            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
            gh = F.linear(h, self.weight_hh, self.bias_hh)
            i_r, i_i, i_n = gi.chunk(3, 1)
            h_r, h_i, h_n = gh.chunk(3, 1)

            resetgate = torch.sigmoid(i_r + h_r)
            inputgate = torch.sigmoid(i_i + h_i)
            newgate = torch.tanh(i_n + resetgate * h_n)

            # AUGRU: the attention score scales the update gate before the state update
            updategate = att_t * inputgate
            h = (1 - updategate) * h + updategate * newgate

            outputs.append(h.unsqueeze(1))

        output = torch.cat(outputs, dim=1)

        return output, h
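# Editorial usage sketch (not part of the original file): the DIEN interest-evolution
# pattern, where a plain GRU extracts interests and AUGRU evolves them under attention.
def _example_dien_gru_stack():
    seq_emb = torch.randn(8, 20, 32)   # [batch, seq_len, input_size]
    interest_states, _ = DynamicGRU(input_size=32, hidden_size=32)(seq_emb)
    att_scores = torch.rand(8, 20, 1)  # per-step attention scores in [0, 1] (illustrative)
    evolved, final_state = AUGRU(input_size=32, hidden_size=32)(interest_states, att_scores)
    return evolved, final_state        # [batch, 20, 32], [batch, 32]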