torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +3 -1
  3. torch_rechub/basic/callback.py +2 -2
  4. torch_rechub/basic/features.py +38 -8
  5. torch_rechub/basic/initializers.py +92 -0
  6. torch_rechub/basic/layers.py +800 -46
  7. torch_rechub/basic/loss_func.py +223 -0
  8. torch_rechub/basic/metaoptimizer.py +76 -0
  9. torch_rechub/basic/metric.py +251 -0
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -0
  14. torch_rechub/models/matching/comirec.py +193 -0
  15. torch_rechub/models/matching/dssm.py +72 -0
  16. torch_rechub/models/matching/dssm_facebook.py +77 -0
  17. torch_rechub/models/matching/dssm_senet.py +87 -0
  18. torch_rechub/models/matching/gru4rec.py +85 -0
  19. torch_rechub/models/matching/mind.py +103 -0
  20. torch_rechub/models/matching/narm.py +82 -0
  21. torch_rechub/models/matching/sasrec.py +143 -0
  22. torch_rechub/models/matching/sine.py +148 -0
  23. torch_rechub/models/matching/stamp.py +81 -0
  24. torch_rechub/models/matching/youtube_dnn.py +75 -0
  25. torch_rechub/models/matching/youtube_sbc.py +98 -0
  26. torch_rechub/models/multi_task/__init__.py +5 -2
  27. torch_rechub/models/multi_task/aitm.py +83 -0
  28. torch_rechub/models/multi_task/esmm.py +19 -8
  29. torch_rechub/models/multi_task/mmoe.py +18 -12
  30. torch_rechub/models/multi_task/ple.py +41 -29
  31. torch_rechub/models/multi_task/shared_bottom.py +3 -2
  32. torch_rechub/models/ranking/__init__.py +13 -2
  33. torch_rechub/models/ranking/afm.py +65 -0
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -0
  36. torch_rechub/models/ranking/dcn.py +38 -0
  37. torch_rechub/models/ranking/dcn_v2.py +59 -0
  38. torch_rechub/models/ranking/deepffm.py +131 -0
  39. torch_rechub/models/ranking/deepfm.py +8 -7
  40. torch_rechub/models/ranking/dien.py +191 -0
  41. torch_rechub/models/ranking/din.py +31 -19
  42. torch_rechub/models/ranking/edcn.py +101 -0
  43. torch_rechub/models/ranking/fibinet.py +42 -0
  44. torch_rechub/models/ranking/widedeep.py +6 -6
  45. torch_rechub/trainers/__init__.py +4 -2
  46. torch_rechub/trainers/ctr_trainer.py +191 -0
  47. torch_rechub/trainers/match_trainer.py +239 -0
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +137 -23
  50. torch_rechub/trainers/seq_trainer.py +293 -0
  51. torch_rechub/utils/__init__.py +0 -0
  52. torch_rechub/utils/data.py +492 -0
  53. torch_rechub/utils/hstu_utils.py +198 -0
  54. torch_rechub/utils/match.py +457 -0
  55. torch_rechub/utils/mtl.py +136 -0
  56. torch_rechub/utils/onnx_export.py +353 -0
  57. torch_rechub-0.0.4.dist-info/METADATA +391 -0
  58. torch_rechub-0.0.4.dist-info/RECORD +62 -0
  59. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
  60. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
  61. torch_rechub/basic/utils.py +0 -168
  62. torch_rechub/trainers/trainer.py +0 -111
  63. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  64. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  65. torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
@@ -1,14 +1,18 @@
+ from itertools import combinations
+
  import torch
  import torch.nn as nn
+ import torch.nn.functional as F
+
  from .activation import activation_layer
- from .features import DenseFeature, SparseFeature, SequenceFeature
+ from .features import DenseFeature, SequenceFeature, SparseFeature
  
  
  class PredictionLayer(nn.Module):
      """Prediction Layer.
  
      Args:
-         task_type (str): if `task_type='classification'`, then return sigmoid(x),
+         task_type (str): if `task_type='classification'`, then return sigmoid(x),
              change the input logits to probability. if`task_type='regression'`, then return x.
      """
  
@@ -25,19 +29,20 @@ class PredictionLayer(nn.Module):
  
  
  class EmbeddingLayer(nn.Module):
-     """General Embedding Layer. We init each embedding layer by `xavier_normal_`.
-
+     """General Embedding Layer.
+     We save all the feature embeddings in embed_dict: `{feature_name : embedding table}`.
+
+
      Args:
          features (list): the list of `Feature Class`. It is means all the features which we want to create a embedding table.
-         embed_dict (dict): the embedding dict, `{feature_name : embedding table}`.
  
      Shape:
-         - Input:
+         - Input:
              x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
              sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
          features (list): the list of `Feature Class`. It is means the current features which we want to do embedding lookup.
-         squeeze_dim (bool): whether to squeeze dim of output (default = `True`).
-         - Output:
+         squeeze_dim (bool): whether to squeeze dim of output (default = `False`).
+         - Output:
          - if input Dense: `(batch_size, num_features_dense)`.
          - if input Sparse: `(batch_size, num_features, embed_dim)` or `(batch_size, num_features * embed_dim)`.
          - if input Sequence: same with input sparse or `(batch_size, num_features_seq, seq_length, embed_dim)` when `pooling=="concat"`.
@@ -51,23 +56,24 @@ class EmbeddingLayer(nn.Module):
          self.n_dense = 0
  
          for fea in features:
-             if fea.name in self.embed_dict: #exist
+             if fea.name in self.embed_dict: # exist
                  continue
-             if isinstance(fea, SparseFeature):
-                 self.embed_dict[fea.name] = nn.Embedding(fea.vocab_size, fea.embed_dim)
-             elif isinstance(fea, SequenceFeature) and fea.shared_with == None:
-                 self.embed_dict[fea.name] = nn.Embedding(fea.vocab_size, fea.embed_dim)
+             if isinstance(fea, SparseFeature) and fea.shared_with is None:
+                 self.embed_dict[fea.name] = fea.get_embedding_layer()
+             elif isinstance(fea, SequenceFeature) and fea.shared_with is None:
+                 self.embed_dict[fea.name] = fea.get_embedding_layer()
              elif isinstance(fea, DenseFeature):
                  self.n_dense += 1
-         for matrix in self.embed_dict.values(): #init embedding weight
-             torch.nn.init.xavier_normal_(matrix.weight)
  
      def forward(self, x, features, squeeze_dim=False):
          sparse_emb, dense_values = [], []
          sparse_exists, dense_exists = False, False
          for fea in features:
              if isinstance(fea, SparseFeature):
-                 sparse_emb.append(self.embed_dict[fea.name](x[fea.name].long()).unsqueeze(1))
+                 if fea.shared_with is None:
+                     sparse_emb.append(self.embed_dict[fea.name](x[fea.name].long()).unsqueeze(1))
+                 else:
+                     sparse_emb.append(self.embed_dict[fea.shared_with](x[fea.name].long()).unsqueeze(1))
              elif isinstance(fea, SequenceFeature):
                  if fea.pooling == "sum":
                      pooling_layer = SumPooling()
@@ -77,38 +83,75 @@ class EmbeddingLayer(nn.Module):
                      pooling_layer = ConcatPooling()
                  else:
                      raise ValueError("Sequence pooling method supports only pooling in %s, got %s." % (["sum", "mean"], fea.pooling))
-                 if fea.shared_with == None:
-                     sparse_emb.append(pooling_layer(self.embed_dict[fea.name](x[fea.name].long())).unsqueeze(1))
+                 fea_mask = InputMask()(x, fea)
+                 if fea.shared_with is None:
+                     sparse_emb.append(pooling_layer(self.embed_dict[fea.name](x[fea.name].long()), fea_mask).unsqueeze(1))
                  else:
-                     sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long())).unsqueeze(1)) #shared specific sparse feature embedding
+                     sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long()), fea_mask).unsqueeze(1)) # shared specific sparse feature embedding
              else:
-                 dense_values.append(x[fea.name].float().unsqueeze(1)) #.unsqueeze(1).unsqueeze(1)
+                 dense_values.append(x[fea.name].float() if x[fea.name].float().dim() > 1 else x[fea.name].float().unsqueeze(1)) # .unsqueeze(1).unsqueeze(1)
  
          if len(dense_values) > 0:
              dense_exists = True
              dense_values = torch.cat(dense_values, dim=1)
          if len(sparse_emb) > 0:
              sparse_exists = True
-             sparse_emb = torch.cat(sparse_emb, dim=1) #[batch_size, num_features, embed_dim]
+             # TODO: support concat dynamic embed_dim in dim 2
+             # [batch_size, num_features, embed_dim]
+             sparse_emb = torch.cat(sparse_emb, dim=1)
  
-         if squeeze_dim: #Note: if the emb_dim of sparse features is different, we must squeeze_dim
-             if dense_exists and not sparse_exists: #only input dense features
+         if squeeze_dim: # Note: if the emb_dim of sparse features is different, we must squeeze_dim
+             if dense_exists and not sparse_exists: # only input dense features
                  return dense_values
              elif not dense_exists and sparse_exists:
-                 return sparse_emb.flatten(start_dim=1) #squeeze dim to : [batch_size, num_features*embed_dim]
+                 # squeeze dim to : [batch_size, num_features*embed_dim]
+                 return sparse_emb.flatten(start_dim=1)
              elif dense_exists and sparse_exists:
-                 return torch.cat((sparse_emb.flatten(start_dim=1), dense_values), dim=1) #concat dense value with sparse embedding
+                 # concat dense value with sparse embedding
+                 return torch.cat((sparse_emb.flatten(start_dim=1), dense_values), dim=1)
              else:
                  raise ValueError("The input features can note be empty")
          else:
              if sparse_exists:
-                 return sparse_emb #[batch_size, num_features, embed_dim]
+                 return sparse_emb # [batch_size, num_features, embed_dim]
              else:
                  raise ValueError("If keep the original shape:[batch_size, num_features, embed_dim], expected %s in feature list, got %s" % ("SparseFeatures", features))
  
  
+ class InputMask(nn.Module):
+     """Return inputs mask from given features
+
+     Shape:
+         - Input:
+             x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
+             sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
+             features (list or SparseFeature or SequenceFeature): Note that the elements in features are either all instances of SparseFeature or all instances of SequenceFeature.
+         - Output:
+             - if input Sparse: `(batch_size, num_features)`
+             - if input Sequence: `(batch_size, num_features_seq, seq_length)`
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, features):
+         mask = []
+         if not isinstance(features, list):
+             features = [features]
+         for fea in features:
+             if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature):
+                 if fea.padding_idx is not None:
+                     fea_mask = x[fea.name].long() != fea.padding_idx
+                 else:
+                     fea_mask = x[fea.name].long() != -1
+                 mask.append(fea_mask.unsqueeze(1).float())
+             else:
+                 raise ValueError("Only SparseFeature or SequenceFeature support to get mask.")
+         return torch.cat(mask, dim=1)
+
+
  class LR(nn.Module):
-     """Logistic Regression Module. It is the one Non-linear
+     """Logistic Regression Module. It is the one Non-linear
      transformation for input feature.
  
      Args:
@@ -134,7 +177,7 @@ class LR(nn.Module):
  
  class ConcatPooling(nn.Module):
      """Keep the origin sequence embedding shape
-
+
      Shape:
          - Input: `(batch_size, seq_length, embed_dim)`
          - Output: `(batch_size, seq_length, embed_dim)`
@@ -143,62 +186,73 @@ class ConcatPooling(nn.Module):
      def __init__(self):
          super().__init__()
  
-     def forward(self, x):
+     def forward(self, x, mask=None):
          return x
  
  
  class AveragePooling(nn.Module):
      """Pooling the sequence embedding matrix by `mean`.
-
+
      Shape:
-         - Input: `(batch_size, seq_length, embed_dim)`
+         - Input
+             x: `(batch_size, seq_length, embed_dim)`
+             mask: `(batch_size, 1, seq_length)`
          - Output: `(batch_size, embed_dim)`
      """
  
      def __init__(self):
          super().__init__()
  
-     def forward(self, x):
-         sum_pooling_matrix = torch.sum(x, dim=1)
-         non_padding_length = (x != 0).sum(dim=1)
-         x = sum_pooling_matrix / (non_padding_length.float() + 1e-16)
-         return x
+     def forward(self, x, mask=None):
+         if mask is None:
+             return torch.mean(x, dim=1)
+         else:
+             sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
+             non_padding_length = mask.sum(dim=-1)
+             return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
  
  
  class SumPooling(nn.Module):
      """Pooling the sequence embedding matrix by `sum`.
  
      Shape:
-         - Input: `(batch_size, seq_length, embed_dim)`
+         - Input
+             x: `(batch_size, seq_length, embed_dim)`
+             mask: `(batch_size, 1, seq_length)`
          - Output: `(batch_size, embed_dim)`
      """
  
      def __init__(self):
          super().__init__()
  
-     def forward(self, x):
-         return torch.sum(x, dim=1)
+     def forward(self, x, mask=None):
+         if mask is None:
+             return torch.sum(x, dim=1)
+         else:
+             return torch.bmm(mask, x).squeeze(1)
  
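In 0.0.4 the pooling layers accept an optional `mask` of shape `(batch_size, 1, seq_length)`; when it is given, padded positions are excluded from the sum/mean via a batched matrix product. A minimal sketch of the new call signature, assuming torch-rechub 0.0.4 is installed and using illustrative tensor sizes:

```python
import torch
from torch_rechub.basic.layers import AveragePooling, SumPooling

seq_emb = torch.randn(4, 10, 16)              # (batch_size, seq_length, embed_dim)
mask = (torch.rand(4, 1, 10) > 0.3).float()   # (batch_size, 1, seq_length), 1.0 = valid position

print(SumPooling()(seq_emb, mask).shape)      # torch.Size([4, 16]); padded positions contribute 0
print(AveragePooling()(seq_emb, mask).shape)  # torch.Size([4, 16]); mean over valid positions only
print(SumPooling()(seq_emb).shape)            # mask=None falls back to a plain sum over dim 1
```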
 
  class MLP(nn.Module):
-     """Multi Layer Perceptron Module, it is the most widely used module for
-     learning feature. Note we default add `BatchNorm1d` and `Activation`
+     """Multi Layer Perceptron Module, it is the most widely used module for
+     learning feature. Note we default add `BatchNorm1d` and `Activation`
      `Dropout` for each `Linear` Module.
  
      Args:
          input dim (int): input size of the first Linear Layer.
-         dims (list): output size of Linear Layer.
+         output_layer (bool): whether this MLP module is the output layer. If `True`, then append one Linear(*,1) module.
+         dims (list): output size of Linear Layer (default=[]).
          dropout (float): probability of an element to be zeroed (default = 0.5).
          activation (str): the activation function, support `[sigmoid, relu, prelu, dice, softmax]` (default='relu').
-         output_layer (bool): whether this MLP module is the output layer. If `True`, then append one Linear(*,1) module.
  
      Shape:
          - Input: `(batch_size, input_dim)`
          - Output: `(batch_size, 1)` or `(batch_size, dims[-1])`
      """
  
-     def __init__(self, input_dim, dims, dropout=0, activation="relu", output_layer=True):
+     def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
          super().__init__()
+         if dims is None:
+             dims = []
          layers = list()
          for i_dim in dims:
              layers.append(nn.Linear(input_dim, i_dim))
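Note the changed `MLP` signature above: `dims` moved behind `output_layer` and now defaults to `None` (treated as `[]`), so an old positional call like `MLP(input_dim, [64, 16])` would bind the list to `output_layer`; pass `dims` as a keyword instead. A minimal sketch (torch-rechub 0.0.4; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import MLP

x = torch.randn(8, 32)
tower = MLP(32, output_layer=False, dims=[64, 16])  # hidden tower, keeps the last hidden size
head = MLP(32, output_layer=True, dims=[64, 16])    # appends a final Linear(16, 1)
print(tower(x).shape)  # torch.Size([8, 16])
print(head(x).shape)   # torch.Size([8, 1])
```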
@@ -216,7 +270,7 @@ class MLP(nn.Module):
  
  class FM(nn.Module):
      """The Factorization Machine module, mentioned in the `DeepFM paper
-     <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
+     <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
      feature interactions.
  
      Args:
@@ -237,4 +291,704 @@ class FM(nn.Module):
          ix = square_of_sum - sum_of_square
          if self.reduce_sum:
              ix = torch.sum(ix, dim=1, keepdim=True)
-         return 0.5 * ix
+         return 0.5 * ix
+
+
+ class CIN(nn.Module):
+     """Compressed Interaction Network
+
+     Args:
+         input_dim (int): input dim of input tensor.
+         cin_size (list[int]): out channels of Conv1d.
+
+     Shape:
+         - Input: `(batch_size, num_features, embed_dim)`
+         - Output: `(batch_size, 1)`
+     """
+
+     def __init__(self, input_dim, cin_size, split_half=True):
+         super().__init__()
+         self.num_layers = len(cin_size)
+         self.split_half = split_half
+         self.conv_layers = torch.nn.ModuleList()
+         prev_dim, fc_input_dim = input_dim, 0
+         for i in range(self.num_layers):
+             cross_layer_size = cin_size[i]
+             self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
+             if self.split_half and i != self.num_layers - 1:
+                 cross_layer_size //= 2
+             prev_dim = cross_layer_size
+             fc_input_dim += prev_dim
+         self.fc = torch.nn.Linear(fc_input_dim, 1)
+
+     def forward(self, x):
+         xs = list()
+         x0, h = x.unsqueeze(2), x
+         for i in range(self.num_layers):
+             x = x0 * h.unsqueeze(1)
+             batch_size, f0_dim, fin_dim, embed_dim = x.shape
+             x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
+             x = F.relu(self.conv_layers[i](x))
+             if self.split_half and i != self.num_layers - 1:
+                 x, h = torch.split(x, x.shape[1] // 2, dim=1)
+             else:
+                 h = x
+             xs.append(x)
+         return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
+
+
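The new `CIN` module implements xDeepFM-style compressed interaction over field embeddings. A minimal shape check, assuming torch-rechub 0.0.4 is installed and using illustrative sizes:

```python
import torch
from torch_rechub.basic.layers import CIN

field_emb = torch.randn(4, 10, 16)          # (batch_size, num_features, embed_dim)
cin = CIN(input_dim=10, cin_size=[16, 16])  # input_dim is the number of feature fields
print(cin(field_emb).shape)                 # torch.Size([4, 1])
```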
+ class CrossLayer(nn.Module):
+     """
+     Cross layer.
+     Args:
+         input_dim (int): input dim of input tensor
+     """
+
+     def __init__(self, input_dim):
+         super(CrossLayer, self).__init__()
+         self.w = torch.nn.Linear(input_dim, 1, bias=False)
+         self.b = torch.nn.Parameter(torch.zeros(input_dim))
+
+     def forward(self, x_0, x_i):
+         x = self.w(x_i) * x_0 + self.b
+         return x
+
+
+ class CrossNetwork(nn.Module):
+     """CrossNetwork mentioned in the DCN paper.
+
+     Args:
+         input_dim (int): input dim of input tensor
+
+     Shape:
+         - Input: `(batch_size, *)`
+         - Output: `(batch_size, *)`
+
+     """
+
+     def __init__(self, input_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
+         self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
+
+     def forward(self, x):
+         """
+         :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
+         """
+         x0 = x
+         for i in range(self.num_layers):
+             xw = self.w[i](x)
+             x = x0 * xw + self.b[i] + x
+         return x
+
+
+ class CrossNetV2(nn.Module):
+
+     def __init__(self, input_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
+         self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
+
+     def forward(self, x):
+         x0 = x
+         for i in range(self.num_layers):
+             x = x0 * self.w[i](x) + self.b[i] + x
+         return x
+
+
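`CrossLayer`, `CrossNetwork` (DCN) and `CrossNetV2` (DCN-V2) all act on a flattened embedding vector and preserve its size, so they can be combined with an MLP tower. A minimal shape check (torch-rechub 0.0.4; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import CrossNetwork, CrossNetV2

x = torch.randn(4, 40)                                     # e.g. 10 fields * embed_dim 4, flattened
print(CrossNetwork(input_dim=40, num_layers=3)(x).shape)   # torch.Size([4, 40])
print(CrossNetV2(input_dim=40, num_layers=3)(x).shape)     # torch.Size([4, 40])
```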
+ class CrossNetMix(nn.Module):
+     """ CrossNetMix improves CrossNetwork by:
+         1. add MOE to learn feature interactions in different subspaces
+         2. add nonlinear transformations in low-dimensional space
+     :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
+     """
+
+     def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
+         super(CrossNetMix, self).__init__()
+         self.num_layers = num_layers
+         self.num_experts = num_experts
+
+         # U: (input_dim, low_rank)
+         self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
+         # V: (input_dim, low_rank)
+         self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
+         # C: (low_rank, low_rank)
+         self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
+         self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
+
+         self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1))) for i in range(self.num_layers)])
+
+     def forward(self, x):
+         x_0 = x.unsqueeze(2) # (bs, in_features, 1)
+         x_l = x_0
+         for i in range(self.num_layers):
+             output_of_experts = []
+             gating_score_experts = []
+             for expert_id in range(self.num_experts):
+                 # (1) G(x_l)
+                 # compute the gating score by x_l
+                 gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))
+
+                 # (2) E(x_l)
+                 # project the input x_l to $\mathbb{R}^{r}$
+                 v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l) # (bs, low_rank, 1)
+
+                 # nonlinear activation in low rank space
+                 v_x = torch.tanh(v_x)
+                 v_x = torch.matmul(self.c_list[i][expert_id], v_x)
+                 v_x = torch.tanh(v_x)
+
+                 # project back to $\mathbb{R}^{d}$
+                 uv_x = torch.matmul(self.u_list[i][expert_id], v_x) # (bs, in_features, 1)
+
+                 dot_ = uv_x + self.bias[i]
+                 dot_ = x_0 * dot_ # Hadamard-product
+
+                 output_of_experts.append(dot_.squeeze(2))
+
+
+             # (3) mixture of low-rank experts
+             output_of_experts = torch.stack(output_of_experts, 2) # (bs, in_features, num_experts)
+             gating_score_experts = torch.stack(gating_score_experts, 1) # (bs, num_experts, 1)
+             moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
+             x_l = moe_out + x_l # (bs, in_features, 1)
+
+         x_l = x_l.squeeze() # (bs, in_features)
+         return x_l
+
+
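`CrossNetMix` keeps the same interface as the other cross networks but mixes low-rank experts per cross layer; note that the trailing `squeeze()` assumes `batch_size > 1`. A minimal sketch (torch-rechub 0.0.4; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import CrossNetMix

x = torch.randn(4, 40)                    # flattened field embeddings
cross_mix = CrossNetMix(input_dim=40, num_layers=2, low_rank=16, num_experts=4)
print(cross_mix(x).shape)                 # torch.Size([4, 40])
```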
+ class SENETLayer(nn.Module):
+     """
+     A weighted feature gating system in the SENet paper
+     Args:
+         num_fields (int): number of feature fields
+
+     Shape:
+         - num_fields: `(batch_size, *)`
+         - Output: `(batch_size, *)`
+     """
+
+     def __init__(self, num_fields, reduction_ratio=3):
+         super(SENETLayer, self).__init__()
+         reduced_size = max(1, int(num_fields / reduction_ratio))
+         self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False), nn.ReLU(), nn.Linear(reduced_size, num_fields, bias=False), nn.ReLU())
+
+     def forward(self, x):
+         z = torch.mean(x, dim=-1, out=None)
+         a = self.mlp(z)
+         v = x * a.unsqueeze(-1)
+         return v
+
+
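`SENETLayer` re-weights whole feature fields: it squeezes each field embedding to a scalar, runs a small two-layer bottleneck MLP, and rescales the embeddings. A minimal sketch (torch-rechub 0.0.4; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import SENETLayer

field_emb = torch.randn(4, 10, 16)                    # (batch_size, num_fields, embed_dim)
senet = SENETLayer(num_fields=10, reduction_ratio=3)
print(senet(field_emb).shape)                         # torch.Size([4, 10, 16]), fields re-weighted
```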
+ class BiLinearInteractionLayer(nn.Module):
+     """
+     Bilinear feature interaction module, which is an improved model of the FFM model
+     Args:
+         num_fields (int): number of feature fields
+         bilinear_type(str): the type bilinear interaction function
+     Shape:
+         - num_fields: `(batch_size, *)`
+         - Output: `(batch_size, *)`
+     """
+
+     def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
+         super(BiLinearInteractionLayer, self).__init__()
+         self.bilinear_type = bilinear_type
+         if self.bilinear_type == "field_all":
+             self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
+         elif self.bilinear_type == "field_each":
+             self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
+         elif self.bilinear_type == "field_interaction":
+             self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i, j in combinations(range(num_fields), 2)])
+         else:
+             raise NotImplementedError()
+
+     def forward(self, x):
+         feature_emb = torch.split(x, 1, dim=1)
+         if self.bilinear_type == "field_all":
+             bilinear_list = [self.bilinear_layer(v_i) * v_j for v_i, v_j in combinations(feature_emb, 2)]
+         elif self.bilinear_type == "field_each":
+             bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]
+         elif self.bilinear_type == "field_interaction":
+             bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]
+         return torch.cat(bilinear_list, dim=1)
+
+
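`BiLinearInteractionLayer` produces one interaction vector per field pair, so the output has `num_fields * (num_fields - 1) / 2` rows. A minimal sketch (torch-rechub 0.0.4; `input_dim` is the per-field embedding size; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import BiLinearInteractionLayer

field_emb = torch.randn(4, 10, 16)   # (batch_size, num_fields, embed_dim)
bilinear = BiLinearInteractionLayer(input_dim=16, num_fields=10, bilinear_type="field_interaction")
print(bilinear(field_emb).shape)     # torch.Size([4, 45, 16]); 45 = 10 * 9 / 2 field pairs
```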
+ class MultiInterestSA(nn.Module):
+     """MultiInterest Attention mentioned in the Comirec paper.
+
+     Args:
+         embedding_dim (int): embedding dim of item embedding
+         interest_num (int): num of interest
+         hidden_dim (int): hidden dim
+
+     Shape:
+         - Input: seq_emb : (batch,seq,emb)
+                  mask : (batch,seq,1)
+         - Output: `(batch_size, interest_num, embedding_dim)`
+
+     """
+
+     def __init__(self, embedding_dim, interest_num, hidden_dim=None):
+         super(MultiInterestSA, self).__init__()
+         self.embedding_dim = embedding_dim
+         self.interest_num = interest_num
+         if hidden_dim is None:
+             self.hidden_dim = self.embedding_dim * 4
+         self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
+         self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
+         self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
+
+     def forward(self, seq_emb, mask=None):
+         H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
+         if mask is not None:
+             A = torch.einsum('bsd, dk -> bsk', H, self.W2) + - \
+                 1.e9 * (1 - mask.float())
+             A = F.softmax(A, dim=1)
+         else:
+             A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
+         A = A.permute(0, 2, 1)
+         multi_interest_emb = torch.matmul(A, seq_emb)
+         return multi_interest_emb
+
+
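`MultiInterestSA` turns a user behaviour sequence into `interest_num` interest vectors via self-attention; the optional mask broadcasts over the interest dimension. A minimal sketch (torch-rechub 0.0.4; sizes are illustrative; `hidden_dim` is left at its default, since only that path sets `self.hidden_dim`):

```python
import torch
from torch_rechub.basic.layers import MultiInterestSA

seq_emb = torch.randn(4, 20, 16)                 # (batch_size, seq_len, embed_dim)
mask = (torch.rand(4, 20, 1) > 0.2).float()      # (batch_size, seq_len, 1), 1.0 = valid item
attn = MultiInterestSA(embedding_dim=16, interest_num=4)
print(attn(seq_emb, mask).shape)                 # torch.Size([4, 4, 16])
```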
+ class CapsuleNetwork(nn.Module):
+     """CapsuleNetwork mentioned in the Comirec and MIND paper.
+
+     Args:
+         hidden_size (int): embedding dim of item embedding
+         seq_len (int): length of the item sequence
+         bilinear_type (int): 0 for MIND, 2 for ComirecDR
+         interest_num (int): num of interest
+         routing_times (int): routing times
+
+     Shape:
+         - Input: seq_emb : (batch,seq,emb)
+                  mask : (batch,seq,1)
+         - Output: `(batch_size, interest_num, embedding_dim)`
+
+     """
+
+     def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
+         super(CapsuleNetwork, self).__init__()
+         self.embedding_dim = embedding_dim # h
+         self.seq_len = seq_len # s
+         self.bilinear_type = bilinear_type
+         self.interest_num = interest_num
+         self.routing_times = routing_times
+
+         self.relu_layer = relu_layer
+         self.stop_grad = True
+         self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
+         if self.bilinear_type == 0: # MIND
+             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
+         elif self.bilinear_type == 1:
+             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
+         else:
+             self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
+
+     def forward(self, item_eb, mask):
+         if self.bilinear_type == 0:
+             item_eb_hat = self.linear(item_eb)
+             item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
+         elif self.bilinear_type == 1:
+             item_eb_hat = self.linear(item_eb)
+         else:
+             u = torch.unsqueeze(item_eb, dim=2)
+             item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+
+         item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
+         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
+         item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
+
+         if self.stop_grad:
+             item_eb_hat_iter = item_eb_hat.detach()
+         else:
+             item_eb_hat_iter = item_eb_hat
+
+         if self.bilinear_type > 0:
+             capsule_weight = torch.zeros(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
+         else:
+             capsule_weight = torch.randn(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
+
+         for i in range(self.routing_times): # dynamic routing, 3 iterations by default
+             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
+             paddings = torch.zeros_like(atten_mask, dtype=torch.float)
+
+             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
+             capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
+             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
+
+             if i < 2:
+                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
+                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
+                 scalar_factor = cap_norm / \
+                     (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
+                 interest_capsule = scalar_factor * interest_capsule
+
+                 delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
+                 delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
+                 capsule_weight = capsule_weight + delta_weight
+             else:
+                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
+                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
+                 scalar_factor = cap_norm / \
+                     (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
+                 interest_capsule = scalar_factor * interest_capsule
+
+         interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
+
+         if self.relu_layer:
+             interest_capsule = self.relu(interest_capsule)
+
+         return interest_capsule
+
+
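`CapsuleNetwork` extracts multiple interests by dynamic routing. Note that the routing loop consumes a 2-D `(batch_size, seq_len)` validity mask rather than the `(batch, seq, 1)` shape named in the docstring. A minimal sketch with the MIND variant (`bilinear_type=0`), assuming torch-rechub 0.0.4 and illustrative sizes:

```python
import torch
from torch_rechub.basic.layers import CapsuleNetwork

seq_emb = torch.randn(4, 20, 16)                 # (batch_size, seq_len, embed_dim)
mask = (torch.rand(4, 20) > 0.2).float()         # (batch_size, seq_len), 1.0 = valid item
capsule = CapsuleNetwork(embedding_dim=16, seq_len=20, bilinear_type=0, interest_num=4)
print(capsule(seq_emb, mask).shape)              # torch.Size([4, 4, 16])
```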
+ class FFM(nn.Module):
+     """The Field-aware Factorization Machine module, mentioned in the `FFM paper
+     <https://dl.acm.org/doi/abs/10.1145/2959100.2959134>`. It explicitly models
+     multi-channel second-order feature interactions, with each feature filed
+     corresponding to one channel.
+
+     Args:
+         num_fields (int): number of feature fields.
+         reduce_sum (bool): whether to sum in embed_dim (default = `True`).
+
+     Shape:
+         - Input: `(batch_size, num_fields, num_fields, embed_dim)`
+         - Output: `(batch_size, num_fields*(num_fields-1)/2, 1)` or `(batch_size, num_fields*(num_fields-1)/2, embed_dim)`
+     """
+
+     def __init__(self, num_fields, reduce_sum=True):
+         super().__init__()
+         self.num_fields = num_fields
+         self.reduce_sum = reduce_sum
+
+     def forward(self, x):
+         # compute (non-redundant) second order field-aware feature crossings
+         crossed_embeddings = []
+         for i in range(self.num_fields - 1):
+             for j in range(i + 1, self.num_fields):
+                 crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
+         crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
+
+         # if reduce_sum is true, the crossing operation is effectively inner
+         # product, other wise Hadamard-product
+         if self.reduce_sum:
+             crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
+         return crossed_embeddings
+
+
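`FFM` consumes field-aware embeddings, i.e. one `embed_dim` vector per (feature, target-field) pair, and emits one crossing per field pair. A minimal shape check (torch-rechub 0.0.4; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import FFM

field_aware_emb = torch.randn(4, 10, 10, 16)   # (batch_size, num_fields, num_fields, embed_dim)
print(FFM(num_fields=10, reduce_sum=True)(field_aware_emb).shape)   # torch.Size([4, 45, 1])
print(FFM(num_fields=10, reduce_sum=False)(field_aware_emb).shape)  # torch.Size([4, 45, 16])
```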
+ class CEN(nn.Module):
+     """The Compose-Excitation Network module, mentioned in the `FAT-DeepFFM paper
+     <https://arxiv.org/abs/1905.06336>`, a modified version of
+     `Squeeze-and-Excitation Network” (SENet) (Hu et al., 2017)`. It is used to
+     highlight the importance of second-order feature crosses.
+
+     Args:
+         embed_dim (int): the dimensionality of categorical value embedding.
+         num_field_crosses (int): the number of second order crosses between feature fields.
+         reduction_ratio (int): the between the dimensions of input layer and hidden layer of the MLP module.
+
+     Shape:
+         - Input: `(batch_size, num_fields, num_fields, embed_dim)`
+         - Output: `(batch_size, num_fields*(num_fields-1)/2 * embed_dim)`
+     """
+
+     def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
+         super().__init__()
+
+         # convolution weight (Eq.7 FAT-DeepFFM)
+         self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
+
+         # two FC layers that computes the field attention
+         self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses // reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
+
+     def forward(self, em):
+         # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape
+         # [batch_size, num_field_crosses]
+         d = F.relu((self.u.squeeze(0) * em).sum(-1))
+
+         # compute field attention (Eq.9), output shape [batch_size,
+         # num_field_crosses]
+         s = self.mlp_att(d)
+
+         # rescale original embedding with field attention (Eq.10), output shape
+         # [batch_size, num_field_crosses, embed_dim]
+         aem = s.unsqueeze(-1) * em
+         return aem.flatten(start_dim=1)
+
+
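Judging from the shapes in its forward pass, `CEN` expects the stacked pairwise crossings (e.g. the output of `FFM(reduce_sum=False)`, shape `(batch_size, num_field_crosses, embed_dim)`) rather than the 4-D tensor named in the docstring. A minimal sketch chaining the two (torch-rechub 0.0.4; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import CEN, FFM

field_aware_emb = torch.randn(4, 10, 10, 16)                     # (batch_size, num_fields, num_fields, embed_dim)
crossed = FFM(num_fields=10, reduce_sum=False)(field_aware_emb)  # (4, 45, 16)
cen = CEN(embed_dim=16, num_field_crosses=45, reduction_ratio=3)
print(cen(crossed).shape)                                        # torch.Size([4, 720]); 45 crossings * embed_dim 16
```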
+ # ============ HSTU Layers (new) ============
+
+
+ class HSTULayer(nn.Module):
+     """Single HSTU layer.
+
+     This layer implements the core HSTU "sequential transduction unit": a
+     multi-head self-attention block with gating and a position-wise FFN, plus
+     residual connections and LayerNorm.
+
+     Args:
+         d_model (int): Hidden dimension of the model. Default: 512.
+         n_heads (int): Number of attention heads. Default: 8.
+         dqk (int): Dimension of query/key per head. Default: 64.
+         dv (int): Dimension of value per head. Default: 64.
+         dropout (float): Dropout rate applied in the layer. Default: 0.1.
+         use_rel_pos_bias (bool): Whether to use relative position bias.
+
+     Shape:
+         - Input: ``(batch_size, seq_len, d_model)``
+         - Output: ``(batch_size, seq_len, d_model)``
+
+     Example:
+         >>> layer = HSTULayer(d_model=512, n_heads=8)
+         >>> x = torch.randn(32, 256, 512)
+         >>> output = layer(x)
+         >>> output.shape
+         torch.Size([32, 256, 512])
+     """
+
+     def __init__(self, d_model=512, n_heads=8, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.dqk = dqk
+         self.dv = dv
+         self.dropout_rate = dropout
+         self.use_rel_pos_bias = use_rel_pos_bias
+
+         # Validate dimensions
+         assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+
+         # Projection 1: d_model -> 2*n_heads*dqk + 2*n_heads*dv
+         proj1_out_dim = 2 * n_heads * dqk + 2 * n_heads * dv
+         self.proj1 = nn.Linear(d_model, proj1_out_dim)
+
+         # Projection 2: n_heads*dv -> d_model
+         self.proj2 = nn.Linear(n_heads * dv, d_model)
+
+         # Feed-forward network (FFN)
+         # Standard Transformer uses 4*d_model as the hidden dimension of FFN
+         ffn_hidden_dim = 4 * d_model
+         self.ffn = nn.Sequential(nn.Linear(d_model, ffn_hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ffn_hidden_dim, d_model), nn.Dropout(dropout))
+
+         # Layer normalization
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+
+         # Dropout
+         self.dropout = nn.Dropout(dropout)
+
+         # Scaling factor for attention scores
+         self.scale = 1.0 / (dqk**0.5)
+
+     def forward(self, x, rel_pos_bias=None):
+         """Forward pass of a single HSTU layer.
+
+         Args:
+             x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
+             rel_pos_bias (Tensor, optional): Relative position bias of shape
+                 ``(1, n_heads, seq_len, seq_len)``.
+
+         Returns:
+             Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
+         """
+         batch_size, seq_len, _ = x.shape
+
+         # Residual connection
+         residual = x
+
+         # Layer normalization
+         x = self.norm1(x)
+
+         # Projection 1: (B, L, D) -> (B, L, 2*H*dqk + 2*H*dv)
+         proj_out = self.proj1(x)
+
+         # Split into Q, K, U, V
+         # Q, K: (B, L, H, dqk)
+         # U, V: (B, L, H, dv)
+         q = proj_out[..., :self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
+         k = proj_out[..., self.n_heads * self.dqk:2 * self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
+         u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
+         v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
+
+         # Transpose to (B, H, L, dqk/dv)
+         q = q.transpose(1, 2) # (B, H, L, dqk)
+         k = k.transpose(1, 2) # (B, H, L, dqk)
+         u = u.transpose(1, 2) # (B, H, L, dv)
+         v = v.transpose(1, 2) # (B, H, L, dv)
+
+         # Compute attention scores: (B, H, L, L)
+         scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+         # Add causal mask (prevent attending to future positions)
+         # For generative models this is required so that position i only attends
+         # to positions <= i.
+         causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
+         scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+         # Add relative position bias if provided
+         if rel_pos_bias is not None:
+             scores = scores + rel_pos_bias
+
+         # Softmax over attention scores
+         attn_weights = F.softmax(scores, dim=-1)
+         attn_weights = self.dropout(attn_weights)
+
+         # Attention output: (B, H, L, dv)
+         attn_output = torch.matmul(attn_weights, v)
+
+         # Gating mechanism: apply a learned gate on top of attention output
+         # First transpose back to (B, L, H, dv)
+         attn_output = attn_output.transpose(1, 2) # (B, L, H, dv)
+         u = u.transpose(1, 2) # (B, L, H, dv)
+
+         # Apply element-wise gate: (B, L, H, dv)
+         gated_output = attn_output * torch.sigmoid(u)
+
+         # Merge heads: (B, L, H*dv)
+         gated_output = gated_output.reshape(batch_size, seq_len, self.n_heads * self.dv)
+
+         # Projection 2: (B, L, H*dv) -> (B, L, D)
+         output = self.proj2(gated_output)
+         output = self.dropout(output)
+
+         # Residual connection
+         output = output + residual
+
+         # Second residual block: LayerNorm + FFN + residual connection
+         residual = output
+         output = self.norm2(output)
+         output = self.ffn(output)
+         output = output + residual
+
+         return output
+
+
+ class HSTUBlock(nn.Module):
+     """Stacked HSTU block.
+
+     This block stacks multiple :class:`HSTULayer` layers to form a deep HSTU
+     encoder for sequential recommendation.
+
+     Args:
+         d_model (int): Hidden dimension of the model. Default: 512.
+         n_heads (int): Number of attention heads. Default: 8.
+         n_layers (int): Number of stacked HSTU layers. Default: 4.
+         dqk (int): Dimension of query/key per head. Default: 64.
+         dv (int): Dimension of value per head. Default: 64.
+         dropout (float): Dropout rate applied in each layer. Default: 0.1.
+         use_rel_pos_bias (bool): Whether to use relative position bias.
+
+     Shape:
+         - Input: ``(batch_size, seq_len, d_model)``
+         - Output: ``(batch_size, seq_len, d_model)``
+
+     Example:
+         >>> block = HSTUBlock(d_model=512, n_heads=8, n_layers=4)
+         >>> x = torch.randn(32, 256, 512)
+         >>> output = block(x)
+         >>> output.shape
+         torch.Size([32, 256, 512])
+     """
+
+     def __init__(self, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+
+         # Create a stack of HSTULayer modules
+         self.layers = nn.ModuleList([HSTULayer(d_model=d_model, n_heads=n_heads, dqk=dqk, dv=dv, dropout=dropout, use_rel_pos_bias=use_rel_pos_bias) for _ in range(n_layers)])
+
+     def forward(self, x, rel_pos_bias=None):
+         """Forward pass through all stacked HSTULayer modules.
+
+         Args:
+             x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
+             rel_pos_bias (Tensor, optional): Relative position bias shared across
+                 all layers.
+
+         Returns:
+             Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
+         """
+         for layer in self.layers:
+             x = layer(x, rel_pos_bias=rel_pos_bias)
+         return x
+
+
+ class InteractingLayer(nn.Module):
+     """Multi-head Self-Attention based Interacting Layer, used in AutoInt model.
+
+     Args:
+         embed_dim (int): the embedding dimension.
+         num_heads (int): the number of attention heads (default=2).
+         dropout (float): the dropout rate (default=0.0).
+         residual (bool): whether to use residual connection (default=True).
+
+     Shape:
+         - Input: `(batch_size, num_fields, embed_dim)`
+         - Output: `(batch_size, num_fields, embed_dim)`
+     """
+
+     def __init__(self, embed_dim, num_heads=2, dropout=0.0, residual=True):
+         super().__init__()
+         if embed_dim % num_heads != 0:
+             raise ValueError("embed_dim must be divisible by num_heads")
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = self.head_dim**-0.5
+         self.residual = residual
+
+         self.W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
+         self.W_K = nn.Linear(embed_dim, embed_dim, bias=False)
+         self.W_V = nn.Linear(embed_dim, embed_dim, bias=False)
+
+         # Residual connection
+         self.W_Res = nn.Linear(embed_dim, embed_dim, bias=False) if residual else None
+         self.dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+     def forward(self, x):
+         """
+         Args:
+             x: input tensor with shape (batch_size, num_fields, embed_dim)
+         """
+         batch_size, num_fields, embed_dim = x.shape
+
+         # Linear projections
+         Q = self.W_Q(x) # (batch_size, num_fields, embed_dim)
+         K = self.W_K(x) # (batch_size, num_fields, embed_dim)
+         V = self.W_V(x) # (batch_size, num_fields, embed_dim)
+
+         # Reshape for multi-head attention
+         # (batch_size, num_heads, num_fields, head_dim)
+         Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+         K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+         V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+
+         # Scaled dot-product attention
+         # (batch_size, num_heads, num_fields, num_fields)
+         attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
+         attn_weights = F.softmax(attn_scores, dim=-1)
+
+         if self.dropout is not None:
+             attn_weights = self.dropout(attn_weights)
+
+         # Apply attention to values
+         # (batch_size, num_heads, num_fields, head_dim)
+         attn_output = torch.matmul(attn_weights, V)
+
+         # Concatenate heads
+         # (batch_size, num_fields, embed_dim)
+         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, num_fields, embed_dim)
+
+         # Residual connection
+         if self.residual and self.W_Res is not None:
+             attn_output = attn_output + self.W_Res(x)
+
+         return F.relu(attn_output)
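`InteractingLayer` is the AutoInt building block: multi-head self-attention over feature fields with an optional residual projection, keeping the `(batch_size, num_fields, embed_dim)` shape so several layers can be stacked. A minimal sketch (torch-rechub 0.0.4; sizes are illustrative):

```python
import torch
from torch_rechub.basic.layers import InteractingLayer

field_emb = torch.randn(4, 10, 16)       # (batch_size, num_fields, embed_dim)
layer = InteractingLayer(embed_dim=16, num_heads=2, dropout=0.1)
print(layer(field_emb).shape)            # torch.Size([4, 10, 16])
```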