torch-rechub 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +54 -54
  3. torch_rechub/basic/callback.py +33 -33
  4. torch_rechub/basic/features.py +87 -94
  5. torch_rechub/basic/initializers.py +92 -92
  6. torch_rechub/basic/layers.py +994 -720
  7. torch_rechub/basic/loss_func.py +223 -34
  8. torch_rechub/basic/metaoptimizer.py +76 -72
  9. torch_rechub/basic/metric.py +251 -250
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -11
  14. torch_rechub/models/matching/comirec.py +193 -188
  15. torch_rechub/models/matching/dssm.py +72 -66
  16. torch_rechub/models/matching/dssm_facebook.py +77 -79
  17. torch_rechub/models/matching/dssm_senet.py +28 -16
  18. torch_rechub/models/matching/gru4rec.py +85 -87
  19. torch_rechub/models/matching/mind.py +103 -101
  20. torch_rechub/models/matching/narm.py +82 -76
  21. torch_rechub/models/matching/sasrec.py +143 -140
  22. torch_rechub/models/matching/sine.py +148 -151
  23. torch_rechub/models/matching/stamp.py +81 -83
  24. torch_rechub/models/matching/youtube_dnn.py +75 -71
  25. torch_rechub/models/matching/youtube_sbc.py +98 -98
  26. torch_rechub/models/multi_task/__init__.py +7 -5
  27. torch_rechub/models/multi_task/aitm.py +83 -84
  28. torch_rechub/models/multi_task/esmm.py +56 -55
  29. torch_rechub/models/multi_task/mmoe.py +58 -58
  30. torch_rechub/models/multi_task/ple.py +116 -130
  31. torch_rechub/models/multi_task/shared_bottom.py +45 -45
  32. torch_rechub/models/ranking/__init__.py +14 -11
  33. torch_rechub/models/ranking/afm.py +65 -63
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -63
  36. torch_rechub/models/ranking/dcn.py +38 -38
  37. torch_rechub/models/ranking/dcn_v2.py +59 -69
  38. torch_rechub/models/ranking/deepffm.py +131 -123
  39. torch_rechub/models/ranking/deepfm.py +43 -42
  40. torch_rechub/models/ranking/dien.py +191 -191
  41. torch_rechub/models/ranking/din.py +93 -91
  42. torch_rechub/models/ranking/edcn.py +101 -117
  43. torch_rechub/models/ranking/fibinet.py +42 -50
  44. torch_rechub/models/ranking/widedeep.py +41 -41
  45. torch_rechub/trainers/__init__.py +4 -3
  46. torch_rechub/trainers/ctr_trainer.py +288 -128
  47. torch_rechub/trainers/match_trainer.py +336 -170
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +356 -207
  50. torch_rechub/trainers/seq_trainer.py +427 -0
  51. torch_rechub/utils/data.py +492 -360
  52. torch_rechub/utils/hstu_utils.py +198 -0
  53. torch_rechub/utils/match.py +457 -274
  54. torch_rechub/utils/model_utils.py +233 -0
  55. torch_rechub/utils/mtl.py +136 -126
  56. torch_rechub/utils/onnx_export.py +220 -0
  57. torch_rechub/utils/visualization.py +271 -0
  58. torch_rechub-0.0.5.dist-info/METADATA +402 -0
  59. torch_rechub-0.0.5.dist-info/RECORD +64 -0
  60. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info}/WHEEL +1 -2
  61. {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.5.dist-info/licenses}/LICENSE +21 -21
  62. torch_rechub-0.0.3.dist-info/METADATA +0 -177
  63. torch_rechub-0.0.3.dist-info/RECORD +0 -55
  64. torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
@@ -1,720 +1,994 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from itertools import combinations
5
- from .activation import activation_layer
6
- from .features import DenseFeature, SparseFeature, SequenceFeature
7
-
8
-
9
- class PredictionLayer(nn.Module):
10
- """Prediction Layer.
11
-
12
- Args:
13
- task_type (str): if `task_type='classification'`, then return sigmoid(x),
14
- change the input logits to probability. if`task_type='regression'`, then return x.
15
- """
16
-
17
- def __init__(self, task_type='classification'):
18
- super(PredictionLayer, self).__init__()
19
- if task_type not in ["classification", "regression"]:
20
- raise ValueError("task_type must be classification or regression")
21
- self.task_type = task_type
22
-
23
- def forward(self, x):
24
- if self.task_type == "classification":
25
- x = torch.sigmoid(x)
26
- return x
27
-
28
-
29
- class EmbeddingLayer(nn.Module):
30
- """General Embedding Layer.
31
- We save all the feature embeddings in embed_dict: `{feature_name : embedding table}`.
32
-
33
-
34
- Args:
35
- features (list): the list of `Feature Class`. It is means all the features which we want to create a embedding table.
36
-
37
- Shape:
38
- - Input:
39
- x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
40
- sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
41
- features (list): the list of `Feature Class`. It is means the current features which we want to do embedding lookup.
42
- squeeze_dim (bool): whether to squeeze dim of output (default = `False`).
43
- - Output:
44
- - if input Dense: `(batch_size, num_features_dense)`.
45
- - if input Sparse: `(batch_size, num_features, embed_dim)` or `(batch_size, num_features * embed_dim)`.
46
- - if input Sequence: same with input sparse or `(batch_size, num_features_seq, seq_length, embed_dim)` when `pooling=="concat"`.
47
- - if input Dense and Sparse/Sequence: `(batch_size, num_features_sparse * embed_dim)`. Note we must squeeze_dim for concat dense value with sparse embedding.
48
- """
49
-
50
- def __init__(self, features):
51
- super().__init__()
52
- self.features = features
53
- self.embed_dict = nn.ModuleDict()
54
- self.n_dense = 0
55
-
56
- for fea in features:
57
- if fea.name in self.embed_dict: #exist
58
- continue
59
- if isinstance(fea, SparseFeature) and fea.shared_with == None:
60
- self.embed_dict[fea.name] = fea.get_embedding_layer()
61
- elif isinstance(fea, SequenceFeature) and fea.shared_with == None:
62
- self.embed_dict[fea.name] = fea.get_embedding_layer()
63
- elif isinstance(fea, DenseFeature):
64
- self.n_dense += 1
65
-
66
- def forward(self, x, features, squeeze_dim=False):
67
- sparse_emb, dense_values = [], []
68
- sparse_exists, dense_exists = False, False
69
- for fea in features:
70
- if isinstance(fea, SparseFeature):
71
- if fea.shared_with == None:
72
- sparse_emb.append(self.embed_dict[fea.name](x[fea.name].long()).unsqueeze(1))
73
- else:
74
- sparse_emb.append(self.embed_dict[fea.shared_with](x[fea.name].long()).unsqueeze(1))
75
- elif isinstance(fea, SequenceFeature):
76
- if fea.pooling == "sum":
77
- pooling_layer = SumPooling()
78
- elif fea.pooling == "mean":
79
- pooling_layer = AveragePooling()
80
- elif fea.pooling == "concat":
81
- pooling_layer = ConcatPooling()
82
- else:
83
- raise ValueError("Sequence pooling method supports only pooling in %s, got %s." %
84
- (["sum", "mean"], fea.pooling))
85
- fea_mask = InputMask()(x, fea)
86
- if fea.shared_with == None:
87
- sparse_emb.append(pooling_layer(self.embed_dict[fea.name](x[fea.name].long()), fea_mask).unsqueeze(1))
88
- else:
89
- sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long()), fea_mask).unsqueeze(1)) #shared specific sparse feature embedding
90
- else:
91
- dense_values.append(x[fea.name].float() if x[fea.name].float().dim() > 1 else x[fea.name].float().unsqueeze(1)) #.unsqueeze(1).unsqueeze(1)
92
-
93
- if len(dense_values) > 0:
94
- dense_exists = True
95
- dense_values = torch.cat(dense_values, dim=1)
96
- if len(sparse_emb) > 0:
97
- sparse_exists = True
98
- # TODO: support concat dynamic embed_dim in dim 2
99
- sparse_emb = torch.cat(sparse_emb, dim=1) #[batch_size, num_features, embed_dim]
100
-
101
- if squeeze_dim: #Note: if the emb_dim of sparse features is different, we must squeeze_dim
102
- if dense_exists and not sparse_exists: #only input dense features
103
- return dense_values
104
- elif not dense_exists and sparse_exists:
105
- return sparse_emb.flatten(start_dim=1) #squeeze dim to : [batch_size, num_features*embed_dim]
106
- elif dense_exists and sparse_exists:
107
- return torch.cat((sparse_emb.flatten(start_dim=1), dense_values),
108
- dim=1) #concat dense value with sparse embedding
109
- else:
110
- raise ValueError("The input features can note be empty")
111
- else:
112
- if sparse_exists:
113
- return sparse_emb #[batch_size, num_features, embed_dim]
114
- else:
115
- raise ValueError(
116
- "If keep the original shape:[batch_size, num_features, embed_dim], expected %s in feature list, got %s" %
117
- ("SparseFeatures", features))
118
-
119
-
120
- class InputMask(nn.Module):
121
- """Return inputs mask from given features
122
-
123
- Shape:
124
- - Input:
125
- x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
126
- sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
127
- features (list or SparseFeature or SequenceFeature): Note that the elements in features are either all instances of SparseFeature or all instances of SequenceFeature.
128
- - Output:
129
- - if input Sparse: `(batch_size, num_features)`
130
- - if input Sequence: `(batch_size, num_features_seq, seq_length)`
131
- """
132
-
133
- def __init__(self):
134
- super().__init__()
135
-
136
- def forward(self, x, features):
137
- mask = []
138
- if not isinstance(features, list):
139
- features = [features]
140
- for fea in features:
141
- if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature):
142
- if fea.padding_idx != None:
143
- fea_mask = x[fea.name].long() != fea.padding_idx
144
- else:
145
- fea_mask = x[fea.name].long() != -1
146
- mask.append(fea_mask.unsqueeze(1).float())
147
- else:
148
- raise ValueError("Only SparseFeature or SequenceFeature support to get mask.")
149
- return torch.cat(mask, dim=1)
150
-
151
-
152
- class LR(nn.Module):
153
- """Logistic Regression Module. It is the one Non-linear
154
- transformation for input feature.
155
-
156
- Args:
157
- input_dim (int): input size of Linear module.
158
- sigmoid (bool): whether to add sigmoid function before output.
159
-
160
- Shape:
161
- - Input: `(batch_size, input_dim)`
162
- - Output: `(batch_size, 1)`
163
- """
164
-
165
- def __init__(self, input_dim, sigmoid=False):
166
- super().__init__()
167
- self.sigmoid = sigmoid
168
- self.fc = nn.Linear(input_dim, 1, bias=True)
169
-
170
- def forward(self, x):
171
- if self.sigmoid:
172
- return torch.sigmoid(self.fc(x))
173
- else:
174
- return self.fc(x)
175
-
176
-
177
- class ConcatPooling(nn.Module):
178
- """Keep the origin sequence embedding shape
179
-
180
- Shape:
181
- - Input: `(batch_size, seq_length, embed_dim)`
182
- - Output: `(batch_size, seq_length, embed_dim)`
183
- """
184
-
185
- def __init__(self):
186
- super().__init__()
187
-
188
- def forward(self, x, mask=None):
189
- return x
190
-
191
-
192
- class AveragePooling(nn.Module):
193
- """Pooling the sequence embedding matrix by `mean`.
194
-
195
- Shape:
196
- - Input
197
- x: `(batch_size, seq_length, embed_dim)`
198
- mask: `(batch_size, 1, seq_length)`
199
- - Output: `(batch_size, embed_dim)`
200
- """
201
-
202
- def __init__(self):
203
- super().__init__()
204
-
205
- def forward(self, x, mask=None):
206
- if mask == None:
207
- return torch.mean(x, dim=1)
208
- else:
209
- sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
210
- non_padding_length = mask.sum(dim=-1)
211
- return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
212
-
213
-
214
- class SumPooling(nn.Module):
215
- """Pooling the sequence embedding matrix by `sum`.
216
-
217
- Shape:
218
- - Input
219
- x: `(batch_size, seq_length, embed_dim)`
220
- mask: `(batch_size, 1, seq_length)`
221
- - Output: `(batch_size, embed_dim)`
222
- """
223
-
224
- def __init__(self):
225
- super().__init__()
226
-
227
- def forward(self, x, mask=None):
228
- if mask == None:
229
- return torch.sum(x, dim=1)
230
- else:
231
- return torch.bmm(mask, x).squeeze(1)
232
-
233
-
234
- class MLP(nn.Module):
235
- """Multi Layer Perceptron Module, it is the most widely used module for
236
- learning feature. Note we default add `BatchNorm1d` and `Activation`
237
- `Dropout` for each `Linear` Module.
238
-
239
- Args:
240
- input dim (int): input size of the first Linear Layer.
241
- output_layer (bool): whether this MLP module is the output layer. If `True`, then append one Linear(*,1) module.
242
- dims (list): output size of Linear Layer (default=[]).
243
- dropout (float): probability of an element to be zeroed (default = 0.5).
244
- activation (str): the activation function, support `[sigmoid, relu, prelu, dice, softmax]` (default='relu').
245
-
246
- Shape:
247
- - Input: `(batch_size, input_dim)`
248
- - Output: `(batch_size, 1)` or `(batch_size, dims[-1])`
249
- """
250
-
251
- def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
252
- super().__init__()
253
- if dims is None:
254
- dims = []
255
- layers = list()
256
- for i_dim in dims:
257
- layers.append(nn.Linear(input_dim, i_dim))
258
- layers.append(nn.BatchNorm1d(i_dim))
259
- layers.append(activation_layer(activation))
260
- layers.append(nn.Dropout(p=dropout))
261
- input_dim = i_dim
262
- if output_layer:
263
- layers.append(nn.Linear(input_dim, 1))
264
- self.mlp = nn.Sequential(*layers)
265
-
266
- def forward(self, x):
267
- return self.mlp(x)
268
-
269
-
270
- class FM(nn.Module):
271
- """The Factorization Machine module, mentioned in the `DeepFM paper
272
- <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
273
- feature interactions.
274
-
275
- Args:
276
- reduce_sum (bool): whether to sum in embed_dim (default = `True`).
277
-
278
- Shape:
279
- - Input: `(batch_size, num_features, embed_dim)`
280
- - Output: `(batch_size, 1)`` or ``(batch_size, embed_dim)`
281
- """
282
-
283
- def __init__(self, reduce_sum=True):
284
- super().__init__()
285
- self.reduce_sum = reduce_sum
286
-
287
- def forward(self, x):
288
- square_of_sum = torch.sum(x, dim=1)**2
289
- sum_of_square = torch.sum(x**2, dim=1)
290
- ix = square_of_sum - sum_of_square
291
- if self.reduce_sum:
292
- ix = torch.sum(ix, dim=1, keepdim=True)
293
- return 0.5 * ix
294
-
295
-
296
- class CIN(nn.Module):
297
- """Compressed Interaction Network
298
-
299
- Args:
300
- input_dim (int): input dim of input tensor.
301
- cin_size (list[int]): out channels of Conv1d.
302
-
303
- Shape:
304
- - Input: `(batch_size, num_features, embed_dim)`
305
- - Output: `(batch_size, 1)`
306
- """
307
-
308
- def __init__(self, input_dim, cin_size, split_half=True):
309
- super().__init__()
310
- self.num_layers = len(cin_size)
311
- self.split_half = split_half
312
- self.conv_layers = torch.nn.ModuleList()
313
- prev_dim, fc_input_dim = input_dim, 0
314
- for i in range(self.num_layers):
315
- cross_layer_size = cin_size[i]
316
- self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
317
- if self.split_half and i != self.num_layers - 1:
318
- cross_layer_size //= 2
319
- prev_dim = cross_layer_size
320
- fc_input_dim += prev_dim
321
- self.fc = torch.nn.Linear(fc_input_dim, 1)
322
-
323
- def forward(self, x):
324
- xs = list()
325
- x0, h = x.unsqueeze(2), x
326
- for i in range(self.num_layers):
327
- x = x0 * h.unsqueeze(1)
328
- batch_size, f0_dim, fin_dim, embed_dim = x.shape
329
- x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
330
- x = F.relu(self.conv_layers[i](x))
331
- if self.split_half and i != self.num_layers - 1:
332
- x, h = torch.split(x, x.shape[1] // 2, dim=1)
333
- else:
334
- h = x
335
- xs.append(x)
336
- return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
337
-
338
- class CrossLayer(nn.Module):
339
- """
340
- Cross layer.
341
- Args:
342
- input_dim (int): input dim of input tensor
343
- """
344
- def __init__(self, input_dim):
345
- super(CrossLayer, self).__init__()
346
- self.w = torch.nn.Linear(input_dim, 1, bias=False)
347
- self.b = torch.nn.Parameter(torch.zeros(input_dim))
348
-
349
- def forward(self, x_0, x_i):
350
- x = self.w(x_i) * x_0 + self.b
351
- return x
352
-
353
-
354
- class CrossNetwork(nn.Module):
355
- """CrossNetwork mentioned in the DCN paper.
356
-
357
- Args:
358
- input_dim (int): input dim of input tensor
359
-
360
- Shape:
361
- - Input: `(batch_size, *)`
362
- - Output: `(batch_size, *)`
363
-
364
- """
365
-
366
- def __init__(self, input_dim, num_layers):
367
- super().__init__()
368
- self.num_layers = num_layers
369
- self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
370
- self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
371
-
372
- def forward(self, x):
373
- """
374
- :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
375
- """
376
- x0 = x
377
- for i in range(self.num_layers):
378
- xw = self.w[i](x)
379
- x = x0 * xw + self.b[i] + x
380
- return x
381
-
382
- class CrossNetV2(nn.Module):
383
- def __init__(self, input_dim, num_layers):
384
- super().__init__()
385
- self.num_layers = num_layers
386
- self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
387
- self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
388
-
389
-
390
- def forward(self, x):
391
- x0 = x
392
- for i in range(self.num_layers):
393
- x =x0*self.w[i](x) + self.b[i] + x
394
- return x
395
-
396
- class CrossNetMix(nn.Module):
397
- """ CrossNetMix improves CrossNetwork by:
398
- 1. add MOE to learn feature interactions in different subspaces
399
- 2. add nonlinear transformations in low-dimensional space
400
- :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
401
- """
402
-
403
- def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
404
- super(CrossNetMix, self).__init__()
405
- self.num_layers = num_layers
406
- self.num_experts = num_experts
407
-
408
- # U: (input_dim, low_rank)
409
- self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
410
- torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
411
- # V: (input_dim, low_rank)
412
- self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
413
- torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
414
- # C: (low_rank, low_rank)
415
- self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
416
- torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
417
- self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
418
-
419
- self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
420
- torch.empty(input_dim, 1))) for i in range(self.num_layers)])
421
-
422
- def forward(self, x):
423
- x_0 = x.unsqueeze(2) # (bs, in_features, 1)
424
- x_l = x_0
425
- for i in range(self.num_layers):
426
- output_of_experts = []
427
- gating_score_experts = []
428
- for expert_id in range(self.num_experts):
429
- # (1) G(x_l)
430
- # compute the gating score by x_l
431
- gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))
432
-
433
- # (2) E(x_l)
434
- # project the input x_l to $\mathbb{R}^{r}$
435
- v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l) # (bs, low_rank, 1)
436
-
437
- # nonlinear activation in low rank space
438
- v_x = torch.tanh(v_x)
439
- v_x = torch.matmul(self.c_list[i][expert_id], v_x)
440
- v_x = torch.tanh(v_x)
441
-
442
- # project back to $\mathbb{R}^{d}$
443
- uv_x = torch.matmul(self.u_list[i][expert_id], v_x) # (bs, in_features, 1)
444
-
445
- dot_ = uv_x + self.bias[i]
446
- dot_ = x_0 * dot_ # Hadamard-product
447
-
448
- output_of_experts.append(dot_.squeeze(2))
449
-
450
- # (3) mixture of low-rank experts
451
- output_of_experts = torch.stack(output_of_experts, 2) # (bs, in_features, num_experts)
452
- gating_score_experts = torch.stack(gating_score_experts, 1) # (bs, num_experts, 1)
453
- moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
454
- x_l = moe_out + x_l # (bs, in_features, 1)
455
-
456
- x_l = x_l.squeeze() # (bs, in_features)
457
- return x_l
458
-
459
- class SENETLayer(nn.Module):
460
- """
461
- A weighted feature gating system in the SENet paper
462
- Args:
463
- num_fields (int): number of feature fields
464
-
465
- Shape:
466
- - num_fields: `(batch_size, *)`
467
- - Output: `(batch_size, *)`
468
- """
469
- def __init__(self, num_fields, reduction_ratio=3):
470
- super(SENETLayer, self).__init__()
471
- reduced_size = max(1, int(num_fields/ reduction_ratio))
472
- self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
473
- nn.ReLU(),
474
- nn.Linear(reduced_size, num_fields, bias=False),
475
- nn.ReLU())
476
- def forward(self, x):
477
- z = torch.mean(x, dim=-1, out=None)
478
- a = self.mlp(z)
479
- v = x*a.unsqueeze(-1)
480
- return v
481
-
482
- class BiLinearInteractionLayer(nn.Module):
483
- """
484
- Bilinear feature interaction module, which is an improved model of the FFM model
485
- Args:
486
- num_fields (int): number of feature fields
487
- bilinear_type(str): the type bilinear interaction function
488
- Shape:
489
- - num_fields: `(batch_size, *)`
490
- - Output: `(batch_size, *)`
491
- """
492
- def __init__(self, input_dim, num_fields, bilinear_type = "field_interaction"):
493
- super(BiLinearInteractionLayer, self).__init__()
494
- self.bilinear_type = bilinear_type
495
- if self.bilinear_type == "field_all":
496
- self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
497
- elif self.bilinear_type == "field_each":
498
- self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
499
- elif self.bilinear_type == "field_interaction":
500
- self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i,j in combinations(range(num_fields), 2)])
501
- else:
502
- raise NotImplementedError()
503
-
504
- def forward(self, x):
505
- feature_emb = torch.split(x, 1, dim=1)
506
- if self.bilinear_type == "field_all":
507
- bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
508
- elif self.bilinear_type == "field_each":
509
- bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)]
510
- elif self.bilinear_type == "field_interaction":
511
- bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
512
- return torch.cat(bilinear_list, dim=1)
513
-
514
-
515
-
516
-
517
- class MultiInterestSA(nn.Module):
518
- """MultiInterest Attention mentioned in the Comirec paper.
519
-
520
- Args:
521
- embedding_dim (int): embedding dim of item embedding
522
- interest_num (int): num of interest
523
- hidden_dim (int): hidden dim
524
-
525
- Shape:
526
- - Input: seq_emb : (batch,seq,emb)
527
- mask : (batch,seq,1)
528
- - Output: `(batch_size, interest_num, embedding_dim)`
529
-
530
- """
531
-
532
- def __init__(self, embedding_dim, interest_num, hidden_dim=None):
533
- super(MultiInterestSA, self).__init__()
534
- self.embedding_dim = embedding_dim
535
- self.interest_num = interest_num
536
- if hidden_dim == None:
537
- self.hidden_dim = self.embedding_dim * 4
538
- self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
539
- self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
540
- self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
541
-
542
- def forward(self, seq_emb, mask=None):
543
- H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
544
- if mask != None:
545
- A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
546
- A = F.softmax(A, dim=1)
547
- else:
548
- A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
549
- A = A.permute(0, 2, 1)
550
- multi_interest_emb = torch.matmul(A, seq_emb)
551
- return multi_interest_emb
552
-
553
-
554
- class CapsuleNetwork(nn.Module):
555
- """CapsuleNetwork mentioned in the Comirec and MIND paper.
556
-
557
- Args:
558
- hidden_size (int): embedding dim of item embedding
559
- seq_len (int): length of the item sequence
560
- bilinear_type (int): 0 for MIND, 2 for ComirecDR
561
- interest_num (int): num of interest
562
- routing_times (int): routing times
563
-
564
- Shape:
565
- - Input: seq_emb : (batch,seq,emb)
566
- mask : (batch,seq,1)
567
- - Output: `(batch_size, interest_num, embedding_dim)`
568
-
569
- """
570
-
571
- def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
572
- super(CapsuleNetwork, self).__init__()
573
- self.embedding_dim = embedding_dim # h
574
- self.seq_len = seq_len # s
575
- self.bilinear_type = bilinear_type
576
- self.interest_num = interest_num
577
- self.routing_times = routing_times
578
-
579
- self.relu_layer = relu_layer
580
- self.stop_grad = True
581
- self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
582
- if self.bilinear_type == 0: # MIND
583
- self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
584
- elif self.bilinear_type == 1:
585
- self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
586
- else:
587
- self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
588
-
589
- def forward(self, item_eb, mask):
590
- if self.bilinear_type == 0:
591
- item_eb_hat = self.linear(item_eb)
592
- item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
593
- elif self.bilinear_type == 1:
594
- item_eb_hat = self.linear(item_eb)
595
- else:
596
- u = torch.unsqueeze(item_eb, dim=2)
597
- item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
598
-
599
- item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
600
- item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
601
- item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
602
-
603
- if self.stop_grad:
604
- item_eb_hat_iter = item_eb_hat.detach()
605
- else:
606
- item_eb_hat_iter = item_eb_hat
607
-
608
- if self.bilinear_type > 0:
609
- capsule_weight = torch.zeros(item_eb_hat.shape[0],
610
- self.interest_num,
611
- self.seq_len,
612
- device=item_eb.device,
613
- requires_grad=False)
614
- else:
615
- capsule_weight = torch.randn(item_eb_hat.shape[0],
616
- self.interest_num,
617
- self.seq_len,
618
- device=item_eb.device,
619
- requires_grad=False)
620
-
621
- for i in range(self.routing_times): # 动态路由传播3次
622
- atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
623
- paddings = torch.zeros_like(atten_mask, dtype=torch.float)
624
-
625
- capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
626
- capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
627
- capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
628
-
629
- if i < 2:
630
- interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
631
- cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
632
- scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
633
- interest_capsule = scalar_factor * interest_capsule
634
-
635
- delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
636
- delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
637
- capsule_weight = capsule_weight + delta_weight
638
- else:
639
- interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
640
- cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
641
- scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
642
- interest_capsule = scalar_factor * interest_capsule
643
-
644
- interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
645
-
646
- if self.relu_layer:
647
- interest_capsule = self.relu(interest_capsule)
648
-
649
- return interest_capsule
650
-
651
-
652
- class FFM(nn.Module):
653
- """The Field-aware Factorization Machine module, mentioned in the `FFM paper
654
- <https://dl.acm.org/doi/abs/10.1145/2959100.2959134>`. It explicitly models
655
- multi-channel second-order feature interactions, with each feature filed
656
- corresponding to one channel.
657
-
658
- Args:
659
- num_fields (int): number of feature fields.
660
- reduce_sum (bool): whether to sum in embed_dim (default = `True`).
661
-
662
- Shape:
663
- - Input: `(batch_size, num_fields, num_fields, embed_dim)`
664
- - Output: `(batch_size, num_fields*(num_fields-1)/2, 1)` or `(batch_size, num_fields*(num_fields-1)/2, embed_dim)`
665
- """
666
-
667
- def __init__(self, num_fields, reduce_sum=True):
668
- super().__init__()
669
- self.num_fields = num_fields
670
- self.reduce_sum = reduce_sum
671
-
672
- def forward(self, x):
673
- # compute (non-redundant) second order field-aware feature crossings
674
- crossed_embeddings = []
675
- for i in range(self.num_fields-1):
676
- for j in range(i+1, self.num_fields):
677
- crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
678
- crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
679
-
680
- # if reduce_sum is true, the crossing operation is effectively inner product, other wise Hadamard-product
681
- if self.reduce_sum:
682
- crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
683
- return crossed_embeddings
684
-
685
-
686
- class CEN(nn.Module):
687
- """The Compose-Excitation Network module, mentioned in the `FAT-DeepFFM paper
688
- <https://arxiv.org/abs/1905.06336>`, a modified version of
689
- `Squeeze-and-Excitation Network” (SENet) (Hu et al., 2017)`. It is used to
690
- highlight the importance of second-order feature crosses.
691
-
692
- Args:
693
- embed_dim (int): the dimensionality of categorical value embedding.
694
- num_field_crosses (int): the number of second order crosses between feature fields.
695
- reduction_ratio (int): the between the dimensions of input layer and hidden layer of the MLP module.
696
-
697
- Shape:
698
- - Input: `(batch_size, num_fields, num_fields, embed_dim)`
699
- - Output: `(batch_size, num_fields*(num_fields-1)/2 * embed_dim)`
700
- """
701
- def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
702
- super().__init__()
703
-
704
- # convolution weight (Eq.7 FAT-DeepFFM)
705
- self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
706
-
707
- # two FC layers that computes the field attention
708
- self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses//reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
709
-
710
-
711
- def forward(self, em):
712
- # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
713
- d = F.relu((self.u.squeeze(0) * em).sum(-1))
714
-
715
- # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
716
- s = self.mlp_att(d)
717
-
718
- # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
719
- aem = s.unsqueeze(-1) * em
720
- return aem.flatten(start_dim=1)
1
+ from itertools import combinations
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .activation import activation_layer
8
+ from .features import DenseFeature, SequenceFeature, SparseFeature
9
+
10
+
11
+ class PredictionLayer(nn.Module):
12
+ """Prediction Layer.
13
+
14
+ Args:
15
+ task_type (str): if `task_type='classification'`, then return sigmoid(x),
16
+ converting the input logits to a probability. If `task_type='regression'`, then return x.
17
+ """
18
+
19
+ def __init__(self, task_type='classification'):
20
+ super(PredictionLayer, self).__init__()
21
+ if task_type not in ["classification", "regression"]:
22
+ raise ValueError("task_type must be classification or regression")
23
+ self.task_type = task_type
24
+
25
+ def forward(self, x):
26
+ if self.task_type == "classification":
27
+ x = torch.sigmoid(x)
28
+ return x
29
+
30
+
31
+ class EmbeddingLayer(nn.Module):
32
+ """General Embedding Layer.
33
+ We save all the feature embeddings in embed_dict: `{feature_name : embedding table}`.
34
+
35
+
36
+ Args:
37
+ features (list): the list of `Feature Class`. It means all the features for which we want to create an embedding table.
38
+
39
+ Shape:
40
+ - Input:
41
+ x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
42
+ sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
43
+ features (list): the list of `Feature Class`. It means the current features for which we want to do the embedding lookup.
44
+ squeeze_dim (bool): whether to squeeze dim of output (default = `False`).
45
+ - Output:
46
+ - if input Dense: `(batch_size, num_features_dense)`.
47
+ - if input Sparse: `(batch_size, num_features, embed_dim)` or `(batch_size, num_features * embed_dim)`.
48
+ - if input Sequence: same with input sparse or `(batch_size, num_features_seq, seq_length, embed_dim)` when `pooling=="concat"`.
49
+ - if input Dense and Sparse/Sequence: `(batch_size, num_features_sparse * embed_dim)`. Note we must squeeze_dim for concat dense value with sparse embedding.
50
+ """
51
+
52
+ def __init__(self, features):
53
+ super().__init__()
54
+ self.features = features
55
+ self.embed_dict = nn.ModuleDict()
56
+ self.n_dense = 0
57
+
58
+ for fea in features:
59
+ if fea.name in self.embed_dict: # exist
60
+ continue
61
+ if isinstance(fea, SparseFeature) and fea.shared_with is None:
62
+ self.embed_dict[fea.name] = fea.get_embedding_layer()
63
+ elif isinstance(fea, SequenceFeature) and fea.shared_with is None:
64
+ self.embed_dict[fea.name] = fea.get_embedding_layer()
65
+ elif isinstance(fea, DenseFeature):
66
+ self.n_dense += 1
67
+
68
+ def forward(self, x, features, squeeze_dim=False):
69
+ sparse_emb, dense_values = [], []
70
+ sparse_exists, dense_exists = False, False
71
+ for fea in features:
72
+ if isinstance(fea, SparseFeature):
73
+ if fea.shared_with is None:
74
+ sparse_emb.append(self.embed_dict[fea.name](x[fea.name].long()).unsqueeze(1))
75
+ else:
76
+ sparse_emb.append(self.embed_dict[fea.shared_with](x[fea.name].long()).unsqueeze(1))
77
+ elif isinstance(fea, SequenceFeature):
78
+ if fea.pooling == "sum":
79
+ pooling_layer = SumPooling()
80
+ elif fea.pooling == "mean":
81
+ pooling_layer = AveragePooling()
82
+ elif fea.pooling == "concat":
83
+ pooling_layer = ConcatPooling()
84
+ else:
85
+ raise ValueError("Sequence pooling method supports only pooling in %s, got %s." % (["sum", "mean"], fea.pooling))
86
+ fea_mask = InputMask()(x, fea)
87
+ if fea.shared_with is None:
88
+ sparse_emb.append(pooling_layer(self.embed_dict[fea.name](x[fea.name].long()), fea_mask).unsqueeze(1))
89
+ else:
90
+ sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long()), fea_mask).unsqueeze(1)) # shared specific sparse feature embedding
91
+ else:
92
+ dense_values.append(x[fea.name].float() if x[fea.name].float().dim() > 1 else x[fea.name].float().unsqueeze(1)) # .unsqueeze(1).unsqueeze(1)
93
+
94
+ if len(dense_values) > 0:
95
+ dense_exists = True
96
+ dense_values = torch.cat(dense_values, dim=1)
97
+ if len(sparse_emb) > 0:
98
+ sparse_exists = True
99
+ # TODO: support concat dynamic embed_dim in dim 2
100
+ # [batch_size, num_features, embed_dim]
101
+ sparse_emb = torch.cat(sparse_emb, dim=1)
102
+
103
+ if squeeze_dim: # Note: if the emb_dim of sparse features is different, we must squeeze_dim
104
+ if dense_exists and not sparse_exists: # only input dense features
105
+ return dense_values
106
+ elif not dense_exists and sparse_exists:
107
+ # squeeze dim to : [batch_size, num_features*embed_dim]
108
+ return sparse_emb.flatten(start_dim=1)
109
+ elif dense_exists and sparse_exists:
110
+ # concat dense value with sparse embedding
111
+ return torch.cat((sparse_emb.flatten(start_dim=1), dense_values), dim=1)
112
+ else:
113
+ raise ValueError("The input features can note be empty")
114
+ else:
115
+ if sparse_exists:
116
+ return sparse_emb # [batch_size, num_features, embed_dim]
117
+ else:
118
+ raise ValueError("If keep the original shape:[batch_size, num_features, embed_dim], expected %s in feature list, got %s" % ("SparseFeatures", features))
119
+
120
+
121
+ class InputMask(nn.Module):
122
+ """Return inputs mask from given features
123
+
124
+ Shape:
125
+ - Input:
126
+ x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
127
+ sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
128
+ features (list or SparseFeature or SequenceFeature): Note that the elements in features are either all instances of SparseFeature or all instances of SequenceFeature.
129
+ - Output:
130
+ - if input Sparse: `(batch_size, num_features)`
131
+ - if input Sequence: `(batch_size, num_features_seq, seq_length)`
132
+ """
133
+
134
+ def __init__(self):
135
+ super().__init__()
136
+
137
+ def forward(self, x, features):
138
+ mask = []
139
+ if not isinstance(features, list):
140
+ features = [features]
141
+ for fea in features:
142
+ if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature):
143
+ if fea.padding_idx is not None:
144
+ fea_mask = x[fea.name].long() != fea.padding_idx
145
+ else:
146
+ fea_mask = x[fea.name].long() != -1
147
+ mask.append(fea_mask.unsqueeze(1).float())
148
+ else:
149
+ raise ValueError("Only SparseFeature or SequenceFeature support to get mask.")
150
+ return torch.cat(mask, dim=1)
151
+
152
+
153
+ class LR(nn.Module):
154
+ """Logistic Regression Module. It is the one Non-linear
155
+ transformation of the input features, optionally followed by a sigmoid.
156
+
157
+ Args:
158
+ input_dim (int): input size of Linear module.
159
+ sigmoid (bool): whether to add sigmoid function before output.
160
+
161
+ Shape:
162
+ - Input: `(batch_size, input_dim)`
163
+ - Output: `(batch_size, 1)`
164
+ """
165
+
166
+ def __init__(self, input_dim, sigmoid=False):
167
+ super().__init__()
168
+ self.sigmoid = sigmoid
169
+ self.fc = nn.Linear(input_dim, 1, bias=True)
170
+
171
+ def forward(self, x):
172
+ if self.sigmoid:
173
+ return torch.sigmoid(self.fc(x))
174
+ else:
175
+ return self.fc(x)
176
+
177
+
178
+ class ConcatPooling(nn.Module):
179
+ """Keep the origin sequence embedding shape
180
+
181
+ Shape:
182
+ - Input: `(batch_size, seq_length, embed_dim)`
183
+ - Output: `(batch_size, seq_length, embed_dim)`
184
+ """
185
+
186
+ def __init__(self):
187
+ super().__init__()
188
+
189
+ def forward(self, x, mask=None):
190
+ return x
191
+
192
+
193
+ class AveragePooling(nn.Module):
194
+ """Pooling the sequence embedding matrix by `mean`.
195
+
196
+ Shape:
197
+ - Input
198
+ x: `(batch_size, seq_length, embed_dim)`
199
+ mask: `(batch_size, 1, seq_length)`
200
+ - Output: `(batch_size, embed_dim)`
201
+ """
202
+
203
+ def __init__(self):
204
+ super().__init__()
205
+
206
+ def forward(self, x, mask=None):
207
+ if mask is None:
208
+ return torch.mean(x, dim=1)
209
+ else:
210
+ sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
211
+ non_padding_length = mask.sum(dim=-1)
212
+ return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
213
+
214
+
215
+ class SumPooling(nn.Module):
216
+ """Pooling the sequence embedding matrix by `sum`.
217
+
218
+ Shape:
219
+ - Input
220
+ x: `(batch_size, seq_length, embed_dim)`
221
+ mask: `(batch_size, 1, seq_length)`
222
+ - Output: `(batch_size, embed_dim)`
223
+ """
224
+
225
+ def __init__(self):
226
+ super().__init__()
227
+
228
+ def forward(self, x, mask=None):
229
+ if mask is None:
230
+ return torch.sum(x, dim=1)
231
+ else:
232
+ return torch.bmm(mask, x).squeeze(1)
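A minimal usage sketch for the pooling layers above (not part of the package diff; it assumes SumPooling and AveragePooling are importable from torch_rechub.basic.layers and uses the mask convention `(batch_size, 1, seq_length)` with 1 for valid steps):

    import torch
    from torch_rechub.basic.layers import AveragePooling, SumPooling

    x = torch.randn(2, 5, 8)                       # (batch_size, seq_length, embed_dim)
    mask = torch.tensor([[[1., 1., 1., 0., 0.]],
                         [[1., 1., 0., 0., 0.]]])  # (batch_size, 1, seq_length), 1 = valid step
    sum_emb = SumPooling()(x, mask)                # (2, 8): sums only the valid steps
    avg_emb = AveragePooling()(x, mask)            # (2, 8): masked mean over the valid steps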
233
+
234
+
235
+ class MLP(nn.Module):
236
+ """Multi Layer Perceptron Module, it is the most widely used module for
237
+ learning feature. Note we default add `BatchNorm1d` and `Activation`
238
+ `Dropout` for each `Linear` Module.
239
+
240
+ Args:
241
+ input_dim (int): input size of the first Linear layer.
242
+ output_layer (bool): whether this MLP module is the output layer. If `True`, then append one Linear(*,1) module.
243
+ dims (list): output size of Linear Layer (default=[]).
244
+ dropout (float): probability of an element to be zeroed (default = 0).
245
+ activation (str): the activation function, support `[sigmoid, relu, prelu, dice, softmax]` (default='relu').
246
+
247
+ Shape:
248
+ - Input: `(batch_size, input_dim)`
249
+ - Output: `(batch_size, 1)` or `(batch_size, dims[-1])`
250
+ """
251
+
252
+ def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
253
+ super().__init__()
254
+ if dims is None:
255
+ dims = []
256
+ layers = list()
257
+ for i_dim in dims:
258
+ layers.append(nn.Linear(input_dim, i_dim))
259
+ layers.append(nn.BatchNorm1d(i_dim))
260
+ layers.append(activation_layer(activation))
261
+ layers.append(nn.Dropout(p=dropout))
262
+ input_dim = i_dim
263
+ if output_layer:
264
+ layers.append(nn.Linear(input_dim, 1))
265
+ self.mlp = nn.Sequential(*layers)
266
+
267
+ def forward(self, x):
268
+ return self.mlp(x)
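A minimal usage sketch for MLP (not part of the package diff; assumes the class is importable from torch_rechub.basic.layers):

    import torch
    from torch_rechub.basic.layers import MLP

    x = torch.randn(8, 16)                              # (batch_size, input_dim)
    head = MLP(input_dim=16, output_layer=True, dims=[64, 32], dropout=0.2, activation="relu")
    logit = head(x)                                     # (8, 1) because output_layer=True
    tower = MLP(16, output_layer=False, dims=[64, 32])
    hidden = tower(x)                                   # (8, 32), i.e. (batch_size, dims[-1])

Note that each hidden layer includes BatchNorm1d, so a training-mode forward pass needs batch_size > 1.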
269
+
270
+
271
+ class FM(nn.Module):
272
+ """The Factorization Machine module, mentioned in the `DeepFM paper
273
+ <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
274
+ feature interactions.
275
+
276
+ Args:
277
+ reduce_sum (bool): whether to sum in embed_dim (default = `True`).
278
+
279
+ Shape:
280
+ - Input: `(batch_size, num_features, embed_dim)`
281
+ - Output: `(batch_size, 1)` or `(batch_size, embed_dim)`
282
+ """
283
+
284
+ def __init__(self, reduce_sum=True):
285
+ super().__init__()
286
+ self.reduce_sum = reduce_sum
287
+
288
+ def forward(self, x):
289
+ square_of_sum = torch.sum(x, dim=1)**2
290
+ sum_of_square = torch.sum(x**2, dim=1)
291
+ ix = square_of_sum - sum_of_square
292
+ if self.reduce_sum:
293
+ ix = torch.sum(ix, dim=1, keepdim=True)
294
+ return 0.5 * ix
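The forward pass above is the standard 0.5 * ((sum v)^2 - sum v^2) identity; a sketch that checks it against explicit pairwise dot products (not part of the diff; assumes FM is importable from torch_rechub.basic.layers):

    import torch
    from itertools import combinations
    from torch_rechub.basic.layers import FM

    x = torch.randn(4, 6, 8)                       # (batch_size, num_features, embed_dim)
    out = FM(reduce_sum=True)(x)                   # (4, 1)
    # brute-force check: FM equals the sum of <v_i, v_j> over all field pairs i < j
    pairwise = sum((x[:, i] * x[:, j]).sum(-1) for i, j in combinations(range(6), 2))
    assert torch.allclose(out.squeeze(1), pairwise, atol=1e-4)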
295
+
296
+
297
+ class CIN(nn.Module):
298
+ """Compressed Interaction Network
299
+
300
+ Args:
301
+ input_dim (int): the number of feature fields of the input tensor.
302
+ cin_size (list[int]): out channels of Conv1d.
303
+
304
+ Shape:
305
+ - Input: `(batch_size, num_features, embed_dim)`
306
+ - Output: `(batch_size, 1)`
307
+ """
308
+
309
+ def __init__(self, input_dim, cin_size, split_half=True):
310
+ super().__init__()
311
+ self.num_layers = len(cin_size)
312
+ self.split_half = split_half
313
+ self.conv_layers = torch.nn.ModuleList()
314
+ prev_dim, fc_input_dim = input_dim, 0
315
+ for i in range(self.num_layers):
316
+ cross_layer_size = cin_size[i]
317
+ self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
318
+ if self.split_half and i != self.num_layers - 1:
319
+ cross_layer_size //= 2
320
+ prev_dim = cross_layer_size
321
+ fc_input_dim += prev_dim
322
+ self.fc = torch.nn.Linear(fc_input_dim, 1)
323
+
324
+ def forward(self, x):
325
+ xs = list()
326
+ x0, h = x.unsqueeze(2), x
327
+ for i in range(self.num_layers):
328
+ x = x0 * h.unsqueeze(1)
329
+ batch_size, f0_dim, fin_dim, embed_dim = x.shape
330
+ x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
331
+ x = F.relu(self.conv_layers[i](x))
332
+ if self.split_half and i != self.num_layers - 1:
333
+ x, h = torch.split(x, x.shape[1] // 2, dim=1)
334
+ else:
335
+ h = x
336
+ xs.append(x)
337
+ return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
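A minimal usage sketch for CIN (not part of the diff; assumes it is importable from torch_rechub.basic.layers; input_dim must equal the number of feature fields):

    import torch
    from torch_rechub.basic.layers import CIN

    x = torch.randn(4, 10, 16)                     # (batch_size, num_features, embed_dim)
    cin = CIN(input_dim=10, cin_size=[16, 16], split_half=True)
    logit = cin(x)                                 # (4, 1), added to the linear/deep parts in xDeepFM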
338
+
339
+
340
+ class CrossLayer(nn.Module):
341
+ """
342
+ Cross layer.
343
+ Args:
344
+ input_dim (int): input dim of input tensor
345
+ """
346
+
347
+ def __init__(self, input_dim):
348
+ super(CrossLayer, self).__init__()
349
+ self.w = torch.nn.Linear(input_dim, 1, bias=False)
350
+ self.b = torch.nn.Parameter(torch.zeros(input_dim))
351
+
352
+ def forward(self, x_0, x_i):
353
+ x = self.w(x_i) * x_0 + self.b
354
+ return x
355
+
356
+
357
+ class CrossNetwork(nn.Module):
358
+ """CrossNetwork mentioned in the DCN paper.
359
+
360
+ Args:
361
+ input_dim (int): input dim of input tensor
362
+
363
+ Shape:
364
+ - Input: `(batch_size, *)`
365
+ - Output: `(batch_size, *)`
366
+
367
+ """
368
+
369
+ def __init__(self, input_dim, num_layers):
370
+ super().__init__()
371
+ self.num_layers = num_layers
372
+ self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
373
+ self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
374
+
375
+ def forward(self, x):
376
+ """
377
+ :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
378
+ """
379
+ x0 = x
380
+ for i in range(self.num_layers):
381
+ xw = self.w[i](x)
382
+ x = x0 * xw + self.b[i] + x
383
+ return x
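A minimal usage sketch combining CrossNetwork with a deep tower, the usual DCN layout (not part of the diff; assumes both classes are importable from torch_rechub.basic.layers):

    import torch
    from torch_rechub.basic.layers import MLP, CrossNetwork

    x = torch.randn(8, 32)                             # flattened embeddings, (batch_size, input_dim)
    cross = CrossNetwork(input_dim=32, num_layers=3)
    deep = MLP(32, output_layer=False, dims=[64, 16])
    combined = torch.cat([cross(x), deep(x)], dim=1)   # (8, 32 + 16), fed to the prediction head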
384
+
385
+
386
+ class CrossNetV2(nn.Module):
387
+
388
+ def __init__(self, input_dim, num_layers):
389
+ super().__init__()
390
+ self.num_layers = num_layers
391
+ self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
392
+ self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
393
+
394
+ def forward(self, x):
395
+ x0 = x
396
+ for i in range(self.num_layers):
397
+ x = x0 * self.w[i](x) + self.b[i] + x
398
+ return x
399
+
400
+
401
+ class CrossNetMix(nn.Module):
402
+ """ CrossNetMix improves CrossNetwork by:
403
+ 1. adding MoE (mixture of experts) to learn feature interactions in different subspaces
404
+ 2. adding nonlinear transformations in a low-dimensional space
405
+ :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
406
+ """
407
+
408
+ def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
409
+ super(CrossNetMix, self).__init__()
410
+ self.num_layers = num_layers
411
+ self.num_experts = num_experts
412
+
413
+ # U: (input_dim, low_rank)
414
+ self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
415
+ # V: (input_dim, low_rank)
416
+ self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
417
+ # C: (low_rank, low_rank)
418
+ self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
419
+ self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
420
+
421
+ self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1))) for i in range(self.num_layers)])
422
+
423
+ def forward(self, x):
424
+ x_0 = x.unsqueeze(2) # (bs, in_features, 1)
425
+ x_l = x_0
426
+ for i in range(self.num_layers):
427
+ output_of_experts = []
428
+ gating_score_experts = []
429
+ for expert_id in range(self.num_experts):
430
+ # (1) G(x_l)
431
+ # compute the gating score by x_l
432
+ gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))
433
+
434
+ # (2) E(x_l)
435
+ # project the input x_l to $\mathbb{R}^{r}$
436
+ v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l) # (bs, low_rank, 1)
437
+
438
+ # nonlinear activation in low rank space
439
+ v_x = torch.tanh(v_x)
440
+ v_x = torch.matmul(self.c_list[i][expert_id], v_x)
441
+ v_x = torch.tanh(v_x)
442
+
443
+ # project back to $\mathbb{R}^{d}$
444
+ uv_x = torch.matmul(self.u_list[i][expert_id], v_x) # (bs, in_features, 1)
445
+
446
+ dot_ = uv_x + self.bias[i]
447
+ dot_ = x_0 * dot_ # Hadamard-product
448
+
449
+ output_of_experts.append(dot_.squeeze(2))
450
+
451
+
452
+ # (3) mixture of low-rank experts
453
+ output_of_experts = torch.stack(output_of_experts, 2) # (bs, in_features, num_experts)
454
+ gating_score_experts = torch.stack(gating_score_experts, 1) # (bs, num_experts, 1)
455
+ moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
456
+ x_l = moe_out + x_l # (bs, in_features, 1)
457
+
458
+ x_l = x_l.squeeze() # (bs, in_features)
459
+ return x_l
460
+
461
+
462
+ class SENETLayer(nn.Module):
463
+ """
464
+ A weighted feature gating system in the SENet paper
465
+ Args:
466
+ num_fields (int): number of feature fields
467
+
468
+ Shape:
469
+ - Input: `(batch_size, num_fields, embed_dim)`
470
+ - Output: `(batch_size, num_fields, embed_dim)`
471
+ """
472
+
473
+ def __init__(self, num_fields, reduction_ratio=3):
474
+ super(SENETLayer, self).__init__()
475
+ reduced_size = max(1, int(num_fields / reduction_ratio))
476
+ self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False), nn.ReLU(), nn.Linear(reduced_size, num_fields, bias=False), nn.ReLU())
477
+
478
+ def forward(self, x):
479
+ z = torch.mean(x, dim=-1, out=None)
480
+ a = self.mlp(z)
481
+ v = x * a.unsqueeze(-1)
482
+ return v
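A minimal usage sketch for SENETLayer (not part of the diff; assumes it is importable from torch_rechub.basic.layers):

    import torch
    from torch_rechub.basic.layers import SENETLayer

    emb = torch.randn(4, 10, 16)                   # (batch_size, num_fields, embed_dim)
    senet = SENETLayer(num_fields=10, reduction_ratio=3)
    reweighted = senet(emb)                        # (4, 10, 16), each field rescaled by its learned weight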
483
+
484
+
485
+ class BiLinearInteractionLayer(nn.Module):
486
+ """
487
+ Bilinear feature interaction module, which improves on the interaction scheme of the FFM model
488
+ Args:
489
+ num_fields (int): number of feature fields
490
+ bilinear_type (str): the type of bilinear interaction function
491
+ Shape:
492
+ - Input: `(batch_size, num_fields, embed_dim)`
493
+ - Output: `(batch_size, num_fields*(num_fields-1)/2, embed_dim)`
494
+ """
495
+
496
+ def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
497
+ super(BiLinearInteractionLayer, self).__init__()
498
+ self.bilinear_type = bilinear_type
499
+ if self.bilinear_type == "field_all":
500
+ self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
501
+ elif self.bilinear_type == "field_each":
502
+ self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
503
+ elif self.bilinear_type == "field_interaction":
504
+ self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i, j in combinations(range(num_fields), 2)])
505
+ else:
506
+ raise NotImplementedError()
507
+
508
+ def forward(self, x):
509
+ feature_emb = torch.split(x, 1, dim=1)
510
+ if self.bilinear_type == "field_all":
511
+ bilinear_list = [self.bilinear_layer(v_i) * v_j for v_i, v_j in combinations(feature_emb, 2)]
512
+ elif self.bilinear_type == "field_each":
513
+ bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]
514
+ elif self.bilinear_type == "field_interaction":
515
+ bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]
516
+ return torch.cat(bilinear_list, dim=1)
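A minimal usage sketch for BiLinearInteractionLayer (not part of the diff; assumes it is importable from torch_rechub.basic.layers):

    import torch
    from torch_rechub.basic.layers import BiLinearInteractionLayer

    emb = torch.randn(4, 10, 16)                   # (batch_size, num_fields, embed_dim)
    bilinear = BiLinearInteractionLayer(input_dim=16, num_fields=10, bilinear_type="field_interaction")
    pairs = bilinear(emb)                          # (4, 45, 16): one interaction vector per field pair, C(10, 2) = 45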
517
+
518
+
519
+ class MultiInterestSA(nn.Module):
520
+ """MultiInterest Attention mentioned in the Comirec paper.
521
+
522
+ Args:
523
+ embedding_dim (int): embedding dim of item embedding
524
+ interest_num (int): num of interest
525
+ hidden_dim (int): hidden dim
526
+
527
+ Shape:
528
+ - Input: seq_emb : (batch,seq,emb)
529
+ mask : (batch,seq,1)
530
+ - Output: `(batch_size, interest_num, embedding_dim)`
531
+
532
+ """
533
+
534
+ def __init__(self, embedding_dim, interest_num, hidden_dim=None):
535
+ super(MultiInterestSA, self).__init__()
536
+ self.embedding_dim = embedding_dim
537
+ self.interest_num = interest_num
538
+ self.hidden_dim = hidden_dim if hidden_dim is not None else self.embedding_dim * 4
540
+ self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
541
+ self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
542
+ self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
543
+
544
+ def forward(self, seq_emb, mask=None):
545
+ H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
546
+ if mask is not None:
547
+ A = torch.einsum('bsd, dk -> bsk', H, self.W2) - \
548
+ 1.e9 * (1 - mask.float())
549
+ A = F.softmax(A, dim=1)
550
+ else:
551
+ A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
552
+ A = A.permute(0, 2, 1)
553
+ multi_interest_emb = torch.matmul(A, seq_emb)
554
+ return multi_interest_emb
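A minimal usage sketch for MultiInterestSA with the default hidden_dim (not part of the diff; assumes it is importable from torch_rechub.basic.layers):

    import torch
    from torch_rechub.basic.layers import MultiInterestSA

    seq_emb = torch.randn(4, 20, 32)               # (batch_size, seq_len, embedding_dim)
    mask = torch.ones(4, 20, 1)                    # 1 = valid position, 0 = padding
    attn = MultiInterestSA(embedding_dim=32, interest_num=4)
    interests = attn(seq_emb, mask)                # (4, 4, 32): one embedding per extracted interest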
555
+
556
+
557
+ class CapsuleNetwork(nn.Module):
558
+ """CapsuleNetwork mentioned in the Comirec and MIND paper.
559
+
560
+ Args:
561
+ embedding_dim (int): embedding dim of item embedding
562
+ seq_len (int): length of the item sequence
563
+ bilinear_type (int): 0 for MIND, 2 for ComirecDR
564
+ interest_num (int): num of interest
565
+ routing_times (int): routing times
566
+
567
+ Shape:
568
+ - Input: seq_emb : (batch,seq,emb)
569
+ mask : (batch,seq,1)
570
+ - Output: `(batch_size, interest_num, embedding_dim)`
571
+
572
+ """
573
+
574
+ def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
575
+ super(CapsuleNetwork, self).__init__()
576
+ self.embedding_dim = embedding_dim # h
577
+ self.seq_len = seq_len # s
578
+ self.bilinear_type = bilinear_type
579
+ self.interest_num = interest_num
580
+ self.routing_times = routing_times
581
+
582
+ self.relu_layer = relu_layer
583
+ self.stop_grad = True
584
+ self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
585
+ if self.bilinear_type == 0: # MIND
586
+ self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
587
+ elif self.bilinear_type == 1:
588
+ self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
589
+ else:
590
+ self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
591
+
592
+ def forward(self, item_eb, mask):
593
+ if self.bilinear_type == 0:
594
+ item_eb_hat = self.linear(item_eb)
595
+ item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
596
+ elif self.bilinear_type == 1:
597
+ item_eb_hat = self.linear(item_eb)
598
+ else:
599
+ u = torch.unsqueeze(item_eb, dim=2)
600
+ item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
601
+
602
+ item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
603
+ item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
604
+ item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
605
+
606
+ if self.stop_grad:
607
+ item_eb_hat_iter = item_eb_hat.detach()
608
+ else:
609
+ item_eb_hat_iter = item_eb_hat
610
+
611
+ if self.bilinear_type > 0:
612
+ capsule_weight = torch.zeros(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
613
+ else:
614
+ capsule_weight = torch.randn(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
615
+
616
+ for i in range(self.routing_times):  # dynamic routing, repeated routing_times (default 3) iterations
617
+ atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
618
+ paddings = torch.zeros_like(atten_mask, dtype=torch.float)
619
+
620
+ capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
621
+ capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
622
+ capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
623
+
624
+ if i < 2:
625
+ interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
626
+ cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
627
+ scalar_factor = cap_norm / \
628
+ (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
629
+ interest_capsule = scalar_factor * interest_capsule
630
+
631
+ delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
632
+ delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
633
+ capsule_weight = capsule_weight + delta_weight
634
+ else:
635
+ interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
636
+ cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
637
+ scalar_factor = cap_norm / \
638
+ (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
639
+ interest_capsule = scalar_factor * interest_capsule
640
+
641
+ interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
642
+
643
+ if self.relu_layer:
644
+ interest_capsule = self.relu(interest_capsule)
645
+
646
+ return interest_capsule
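A minimal usage sketch for CapsuleNetwork in its ComirecDR configuration (not part of the diff; assumes it is importable from torch_rechub.basic.layers):

    import torch
    from torch_rechub.basic.layers import CapsuleNetwork

    seq_emb = torch.randn(4, 20, 32)               # (batch_size, seq_len, embedding_dim)
    mask = torch.ones(4, 20)                       # 1 = valid position in the behaviour sequence
    capsule = CapsuleNetwork(embedding_dim=32, seq_len=20, bilinear_type=2, interest_num=4)
    interests = capsule(seq_emb, mask)             # (4, 4, 32): interest capsules after dynamic routing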
647
+
648
+
649
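A minimal usage sketch of the dynamic-routing capsule layer above (editorial illustration, not part of the package diff): a padded behavior-sequence embedding and its 0/1 mask go in, one capsule per interest comes out. Batch and tensor sizes are illustrative assumptions; the import path follows the layers module this diff touches.

import torch
from torch_rechub.basic.layers import CapsuleNetwork  # assumed import path

batch_size, seq_len, embedding_dim = 4, 20, 16  # illustrative sizes
capsule = CapsuleNetwork(embedding_dim=embedding_dim, seq_len=seq_len, bilinear_type=2, interest_num=4, routing_times=3)

item_eb = torch.randn(batch_size, seq_len, embedding_dim)  # padded history embeddings
mask = (torch.rand(batch_size, seq_len) > 0.2).long()      # 1 = real item, 0 = padding

interests = capsule(item_eb, mask)                         # (batch_size, interest_num, embedding_dim)
print(interests.shape)                                     # torch.Size([4, 4, 16])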
+ class FFM(nn.Module):
+     """The Field-aware Factorization Machine module, mentioned in the `FFM paper
+     <https://dl.acm.org/doi/abs/10.1145/2959100.2959134>`. It explicitly models
+     multi-channel second-order feature interactions, with each feature field
+     corresponding to one channel.
+
+     Args:
+         num_fields (int): number of feature fields.
+         reduce_sum (bool): whether to sum over embed_dim (default = `True`).
+
+     Shape:
+         - Input: `(batch_size, num_fields, num_fields, embed_dim)`
+         - Output: `(batch_size, num_fields*(num_fields-1)/2, 1)` or `(batch_size, num_fields*(num_fields-1)/2, embed_dim)`
+     """
+
+     def __init__(self, num_fields, reduce_sum=True):
+         super().__init__()
+         self.num_fields = num_fields
+         self.reduce_sum = reduce_sum
+
+     def forward(self, x):
+         # compute (non-redundant) second-order field-aware feature crossings
+         crossed_embeddings = []
+         for i in range(self.num_fields - 1):
+             for j in range(i + 1, self.num_fields):
+                 crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
+         crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
+
+         # if reduce_sum is True, the crossing operation is effectively an inner
+         # product, otherwise a Hadamard product
+         if self.reduce_sum:
+             crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
+         return crossed_embeddings
+
+
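A short sketch (editorial, not part of the diff) of the field-aware input layout FFM expects: x[:, i, j, :] is the embedding of field i specialized for field j, and the module returns the num_fields*(num_fields-1)/2 non-redundant pairwise crosses. Sizes are assumptions.

import torch
from torch_rechub.basic.layers import FFM  # assumed import path

batch_size, num_fields, embed_dim = 2, 5, 8
x = torch.randn(batch_size, num_fields, num_fields, embed_dim)  # field-aware embeddings

print(FFM(num_fields, reduce_sum=True)(x).shape)   # torch.Size([2, 10, 1]),  10 = 5*4/2 crosses
print(FFM(num_fields, reduce_sum=False)(x).shape)  # torch.Size([2, 10, 8]),  kept per-dimension for CEN below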
+ class CEN(nn.Module):
+     """The Compose-Excitation Network module, mentioned in the `FAT-DeepFFM paper
+     <https://arxiv.org/abs/1905.06336>`, a modified version of the
+     `Squeeze-and-Excitation Network (SENet) (Hu et al., 2017)`. It is used to
+     highlight the importance of second-order feature crosses.
+
+     Args:
+         embed_dim (int): the dimensionality of categorical value embedding.
+         num_field_crosses (int): the number of second-order crosses between feature fields.
+         reduction_ratio (int): the ratio between the dimensions of the input layer and the hidden layer of the MLP module.
+
+     Shape:
+         - Input: `(batch_size, num_field_crosses, embed_dim)`
+         - Output: `(batch_size, num_field_crosses * embed_dim)`
+     """
+
+     def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
+         super().__init__()
+
+         # convolution weight (Eq.7 FAT-DeepFFM)
+         self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
+
+         # two FC layers that compute the field attention
+         self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses // reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
+
+     def forward(self, em):
+         # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
+         d = F.relu((self.u.squeeze(0) * em).sum(-1))
+
+         # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
+         s = self.mlp_att(d)
+
+         # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
+         aem = s.unsqueeze(-1) * em
+         return aem.flatten(start_dim=1)
+
+
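A sketch (editorial) of the DeepFFM-style pipeline the two modules above form: Hadamard crosses from FFM with reduce_sum=False are re-weighted field-by-field by CEN and flattened for a downstream MLP. Sizes are illustrative assumptions.

import torch
from torch_rechub.basic.layers import FFM, CEN  # assumed import path

batch_size, num_fields, embed_dim = 2, 5, 8
num_field_crosses = num_fields * (num_fields - 1) // 2  # 10

em = FFM(num_fields, reduce_sum=False)(torch.randn(batch_size, num_fields, num_fields, embed_dim))
cen = CEN(embed_dim=embed_dim, num_field_crosses=num_field_crosses, reduction_ratio=2)

print(cen(em).shape)  # torch.Size([2, 80]) -> attention-rescaled crosses, flattened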
+ # ============ HSTU Layers (new) ============
+
+
+ class HSTULayer(nn.Module):
+     """Single HSTU layer.
+
+     This layer implements the core HSTU "sequential transduction unit": a
+     multi-head self-attention block with gating and a position-wise FFN, plus
+     residual connections and LayerNorm.
+
+     Args:
+         d_model (int): Hidden dimension of the model. Default: 512.
+         n_heads (int): Number of attention heads. Default: 8.
+         dqk (int): Dimension of query/key per head. Default: 64.
+         dv (int): Dimension of value per head. Default: 64.
+         dropout (float): Dropout rate applied in the layer. Default: 0.1.
+         use_rel_pos_bias (bool): Whether to use relative position bias.
+
+     Shape:
+         - Input: ``(batch_size, seq_len, d_model)``
+         - Output: ``(batch_size, seq_len, d_model)``
+
+     Example:
+         >>> layer = HSTULayer(d_model=512, n_heads=8)
+         >>> x = torch.randn(32, 256, 512)
+         >>> output = layer(x)
+         >>> output.shape
+         torch.Size([32, 256, 512])
+     """
+
+     def __init__(self, d_model=512, n_heads=8, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.dqk = dqk
+         self.dv = dv
+         self.dropout_rate = dropout
+         self.use_rel_pos_bias = use_rel_pos_bias
+
+         # Validate dimensions
+         assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+
+         # Projection 1: d_model -> 2*n_heads*dqk + 2*n_heads*dv
+         proj1_out_dim = 2 * n_heads * dqk + 2 * n_heads * dv
+         self.proj1 = nn.Linear(d_model, proj1_out_dim)
+
+         # Projection 2: n_heads*dv -> d_model
+         self.proj2 = nn.Linear(n_heads * dv, d_model)
+
+         # Feed-forward network (FFN)
+         # Standard Transformer uses 4*d_model as the hidden dimension of FFN
+         ffn_hidden_dim = 4 * d_model
+         self.ffn = nn.Sequential(nn.Linear(d_model, ffn_hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ffn_hidden_dim, d_model), nn.Dropout(dropout))
+
+         # Layer normalization
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+
+         # Dropout
+         self.dropout = nn.Dropout(dropout)
+
+         # Scaling factor for attention scores
+         self.scale = 1.0 / (dqk**0.5)
+
+     def forward(self, x, rel_pos_bias=None):
+         """Forward pass of a single HSTU layer.
+
+         Args:
+             x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
+             rel_pos_bias (Tensor, optional): Relative position bias of shape
+                 ``(1, n_heads, seq_len, seq_len)``.
+
+         Returns:
+             Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
+         """
+         batch_size, seq_len, _ = x.shape
+
+         # Residual connection
+         residual = x
+
+         # Layer normalization
+         x = self.norm1(x)
+
+         # Projection 1: (B, L, D) -> (B, L, 2*H*dqk + 2*H*dv)
+         proj_out = self.proj1(x)
+
+         # Split into Q, K, U, V
+         # Q, K: (B, L, H, dqk)
+         # U, V: (B, L, H, dv)
+         q = proj_out[..., :self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
+         k = proj_out[..., self.n_heads * self.dqk:2 * self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
+         u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
+         v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
+
+         # Transpose to (B, H, L, dqk/dv)
+         q = q.transpose(1, 2)  # (B, H, L, dqk)
+         k = k.transpose(1, 2)  # (B, H, L, dqk)
+         u = u.transpose(1, 2)  # (B, H, L, dv)
+         v = v.transpose(1, 2)  # (B, H, L, dv)
+
+         # Compute attention scores: (B, H, L, L)
+         scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+         # Add causal mask (prevent attending to future positions)
+         # For generative models this is required so that position i only attends
+         # to positions <= i.
+         causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
+         scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+         # Add relative position bias if provided
+         if rel_pos_bias is not None:
+             scores = scores + rel_pos_bias
+
+         # Softmax over attention scores
+         attn_weights = F.softmax(scores, dim=-1)
+         attn_weights = self.dropout(attn_weights)
+
+         # Attention output: (B, H, L, dv)
+         attn_output = torch.matmul(attn_weights, v)
+
+         # Gating mechanism: apply a learned gate on top of attention output
+         # First transpose back to (B, L, H, dv)
+         attn_output = attn_output.transpose(1, 2)  # (B, L, H, dv)
+         u = u.transpose(1, 2)  # (B, L, H, dv)
+
+         # Apply element-wise gate: (B, L, H, dv)
+         gated_output = attn_output * torch.sigmoid(u)
+
+         # Merge heads: (B, L, H*dv)
+         gated_output = gated_output.reshape(batch_size, seq_len, self.n_heads * self.dv)
+
+         # Projection 2: (B, L, H*dv) -> (B, L, D)
+         output = self.proj2(gated_output)
+         output = self.dropout(output)
+
+         # Residual connection
+         output = output + residual
+
+         # Second residual block: LayerNorm + FFN + residual connection
+         residual = output
+         output = self.norm2(output)
+         output = self.ffn(output)
+         output = output + residual
+
+         return output
+
+
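A hedged sketch of feeding an explicit relative position bias into a single HSTULayer. The layer simply adds whatever (1, n_heads, seq_len, seq_len) bias it is given; the bucketed bias table built here is an illustrative stand-in, not the scheme the package ships.

import torch
from torch_rechub.basic.layers import HSTULayer  # assumed import path

batch_size, seq_len, d_model, n_heads = 2, 16, 64, 4
layer = HSTULayer(d_model=d_model, n_heads=n_heads, dqk=16, dv=16, dropout=0.0)
x = torch.randn(batch_size, seq_len, d_model)

# illustrative bias: one value per (head, relative offset), gathered into (1, n_heads, L, L)
offsets = torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]  # j - i in [-(L-1), L-1]
bias_table = torch.randn(n_heads, 2 * seq_len - 1)
rel_pos_bias = bias_table[:, offsets + seq_len - 1].unsqueeze(0)           # (1, n_heads, L, L)

print(layer(x, rel_pos_bias=rel_pos_bias).shape)  # torch.Size([2, 16, 64])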
+ class HSTUBlock(nn.Module):
+     """Stacked HSTU block.
+
+     This block stacks multiple :class:`HSTULayer` layers to form a deep HSTU
+     encoder for sequential recommendation.
+
+     Args:
+         d_model (int): Hidden dimension of the model. Default: 512.
+         n_heads (int): Number of attention heads. Default: 8.
+         n_layers (int): Number of stacked HSTU layers. Default: 4.
+         dqk (int): Dimension of query/key per head. Default: 64.
+         dv (int): Dimension of value per head. Default: 64.
+         dropout (float): Dropout rate applied in each layer. Default: 0.1.
+         use_rel_pos_bias (bool): Whether to use relative position bias.
+
+     Shape:
+         - Input: ``(batch_size, seq_len, d_model)``
+         - Output: ``(batch_size, seq_len, d_model)``
+
+     Example:
+         >>> block = HSTUBlock(d_model=512, n_heads=8, n_layers=4)
+         >>> x = torch.randn(32, 256, 512)
+         >>> output = block(x)
+         >>> output.shape
+         torch.Size([32, 256, 512])
+     """
+
+     def __init__(self, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+
+         # Create a stack of HSTULayer modules
+         self.layers = nn.ModuleList([HSTULayer(d_model=d_model, n_heads=n_heads, dqk=dqk, dv=dv, dropout=dropout, use_rel_pos_bias=use_rel_pos_bias) for _ in range(n_layers)])
+
+     def forward(self, x, rel_pos_bias=None):
+         """Forward pass through all stacked HSTULayer modules.
+
+         Args:
+             x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
+             rel_pos_bias (Tensor, optional): Relative position bias shared across
+                 all layers.
+
+         Returns:
+             Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
+         """
+         for layer in self.layers:
+             x = layer(x, rel_pos_bias=rel_pos_bias)
+         return x
+
+
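A sketch (editorial assumption, not the package's training code) of using HSTUBlock as a sequence encoder: embed padded item IDs, run the stacked layers, and read the last position as the next-item query thanks to the causal mask inside each layer. The embedding table and sizes are hypothetical.

import torch
import torch.nn as nn
from torch_rechub.basic.layers import HSTUBlock  # assumed import path

n_items, d_model, batch_size, seq_len = 1000, 64, 2, 32
item_emb = nn.Embedding(n_items + 1, d_model, padding_idx=0)  # hypothetical item table
encoder = HSTUBlock(d_model=d_model, n_heads=4, n_layers=2, dqk=16, dv=16, dropout=0.0)

item_ids = torch.randint(1, n_items + 1, (batch_size, seq_len))
h = encoder(item_emb(item_ids))  # (batch_size, seq_len, d_model)
next_item_repr = h[:, -1]        # summary of the whole (causally masked) history
print(next_item_repr.shape)      # torch.Size([2, 64])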
+ class InteractingLayer(nn.Module):
+     """Multi-head Self-Attention based Interacting Layer, used in AutoInt model.
+
+     Args:
+         embed_dim (int): the embedding dimension.
+         num_heads (int): the number of attention heads (default=2).
+         dropout (float): the dropout rate (default=0.0).
+         residual (bool): whether to use residual connection (default=True).
+
+     Shape:
+         - Input: `(batch_size, num_fields, embed_dim)`
+         - Output: `(batch_size, num_fields, embed_dim)`
+     """
+
+     def __init__(self, embed_dim, num_heads=2, dropout=0.0, residual=True):
+         super().__init__()
+         if embed_dim % num_heads != 0:
+             raise ValueError("embed_dim must be divisible by num_heads")
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = self.head_dim**-0.5
+         self.residual = residual
+
+         self.W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
+         self.W_K = nn.Linear(embed_dim, embed_dim, bias=False)
+         self.W_V = nn.Linear(embed_dim, embed_dim, bias=False)
+
+         # Residual connection
+         self.W_Res = nn.Linear(embed_dim, embed_dim, bias=False) if residual else None
+         self.dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+     def forward(self, x):
+         """
+         Args:
+             x: input tensor with shape (batch_size, num_fields, embed_dim)
+         """
+         batch_size, num_fields, embed_dim = x.shape
+
+         # Linear projections
+         Q = self.W_Q(x)  # (batch_size, num_fields, embed_dim)
+         K = self.W_K(x)  # (batch_size, num_fields, embed_dim)
+         V = self.W_V(x)  # (batch_size, num_fields, embed_dim)
+
+         # Reshape for multi-head attention
+         # (batch_size, num_heads, num_fields, head_dim)
+         Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+         V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Scaled dot-product attention
+         # (batch_size, num_heads, num_fields, num_fields)
+         attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
+         attn_weights = F.softmax(attn_scores, dim=-1)
+
+         if self.dropout is not None:
+             attn_weights = self.dropout(attn_weights)
+
+         # Apply attention to values
+         # (batch_size, num_heads, num_fields, head_dim)
+         attn_output = torch.matmul(attn_weights, V)
+
+         # Concatenate heads
+         # (batch_size, num_fields, embed_dim)
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, num_fields, embed_dim)
+
+         # Residual connection
+         if self.residual and self.W_Res is not None:
+             attn_output = attn_output + self.W_Res(x)
+
+         return F.relu(attn_output)
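An AutoInt-style stacking sketch (editorial; the stacking depth and prediction head are assumptions): several InteractingLayer modules applied to per-field embeddings, then flattened into a single logit.

import torch
import torch.nn as nn
from torch_rechub.basic.layers import InteractingLayer  # assumed import path

batch_size, num_fields, embed_dim = 4, 10, 16
field_emb = torch.randn(batch_size, num_fields, embed_dim)  # one embedding per feature field

# stack a few interacting layers to model higher-order feature interactions
interacting = nn.Sequential(
    InteractingLayer(embed_dim, num_heads=2, dropout=0.1),
    InteractingLayer(embed_dim, num_heads=2, dropout=0.1),
)
out = interacting(field_emb)  # (batch_size, num_fields, embed_dim)

logit = nn.Linear(num_fields * embed_dim, 1)(out.flatten(start_dim=1))  # illustrative head
print(logit.shape)  # torch.Size([4, 1])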