torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +3 -1
- torch_rechub/basic/callback.py +2 -2
- torch_rechub/basic/features.py +38 -8
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +800 -46
- torch_rechub/basic/loss_func.py +223 -0
- torch_rechub/basic/metaoptimizer.py +76 -0
- torch_rechub/basic/metric.py +251 -0
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -0
- torch_rechub/models/matching/comirec.py +193 -0
- torch_rechub/models/matching/dssm.py +72 -0
- torch_rechub/models/matching/dssm_facebook.py +77 -0
- torch_rechub/models/matching/dssm_senet.py +87 -0
- torch_rechub/models/matching/gru4rec.py +85 -0
- torch_rechub/models/matching/mind.py +103 -0
- torch_rechub/models/matching/narm.py +82 -0
- torch_rechub/models/matching/sasrec.py +143 -0
- torch_rechub/models/matching/sine.py +148 -0
- torch_rechub/models/matching/stamp.py +81 -0
- torch_rechub/models/matching/youtube_dnn.py +75 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -2
- torch_rechub/models/multi_task/aitm.py +83 -0
- torch_rechub/models/multi_task/esmm.py +19 -8
- torch_rechub/models/multi_task/mmoe.py +18 -12
- torch_rechub/models/multi_task/ple.py +41 -29
- torch_rechub/models/multi_task/shared_bottom.py +3 -2
- torch_rechub/models/ranking/__init__.py +13 -2
- torch_rechub/models/ranking/afm.py +65 -0
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +59 -0
- torch_rechub/models/ranking/deepffm.py +131 -0
- torch_rechub/models/ranking/deepfm.py +8 -7
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +31 -19
- torch_rechub/models/ranking/edcn.py +101 -0
- torch_rechub/models/ranking/fibinet.py +42 -0
- torch_rechub/models/ranking/widedeep.py +6 -6
- torch_rechub/trainers/__init__.py +4 -2
- torch_rechub/trainers/ctr_trainer.py +191 -0
- torch_rechub/trainers/match_trainer.py +239 -0
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +137 -23
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +492 -0
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -0
- torch_rechub/utils/mtl.py +136 -0
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub/trainers/trainer.py +0 -111
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
torch_rechub/basic/layers.py
CHANGED
@@ -1,14 +1,18 @@
+from itertools import combinations
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+
 from .activation import activation_layer
-from .features import DenseFeature,
+from .features import DenseFeature, SequenceFeature, SparseFeature
 
 
 class PredictionLayer(nn.Module):
     """Prediction Layer.
 
     Args:
-        task_type (str): if `task_type='classification'`, then return sigmoid(x),
+        task_type (str): if `task_type='classification'`, then return sigmoid(x),
                     change the input logits to probability. if`task_type='regression'`, then return x.
     """
 
@@ -25,19 +29,20 @@ class PredictionLayer(nn.Module):
 
 
 class EmbeddingLayer(nn.Module):
-    """General Embedding Layer.
-
+    """General Embedding Layer.
+    We save all the feature embeddings in embed_dict: `{feature_name : embedding table}`.
+
+
     Args:
         features (list): the list of `Feature Class`. It is means all the features which we want to create a embedding table.
-        embed_dict (dict): the embedding dict, `{feature_name : embedding table}`.
 
     Shape:
-        - Input:
+        - Input:
            x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
            sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
            features (list): the list of `Feature Class`. It is means the current features which we want to do embedding lookup.
-           squeeze_dim (bool): whether to squeeze dim of output (default = `
-        - Output:
+           squeeze_dim (bool): whether to squeeze dim of output (default = `False`).
+        - Output:
            - if input Dense: `(batch_size, num_features_dense)`.
            - if input Sparse: `(batch_size, num_features, embed_dim)` or `(batch_size, num_features * embed_dim)`.
            - if input Sequence: same with input sparse or `(batch_size, num_features_seq, seq_length, embed_dim)` when `pooling=="concat"`.
@@ -51,23 +56,24 @@ class EmbeddingLayer(nn.Module):
         self.n_dense = 0
 
         for fea in features:
-            if fea.name in self.embed_dict: #exist
+            if fea.name in self.embed_dict: # exist
                 continue
-            if isinstance(fea, SparseFeature):
-                self.embed_dict[fea.name] =
-            elif isinstance(fea, SequenceFeature) and fea.shared_with
-                self.embed_dict[fea.name] =
+            if isinstance(fea, SparseFeature) and fea.shared_with is None:
+                self.embed_dict[fea.name] = fea.get_embedding_layer()
+            elif isinstance(fea, SequenceFeature) and fea.shared_with is None:
+                self.embed_dict[fea.name] = fea.get_embedding_layer()
             elif isinstance(fea, DenseFeature):
                 self.n_dense += 1
-        for matrix in self.embed_dict.values(): #init embedding weight
-            torch.nn.init.xavier_normal_(matrix.weight)
 
     def forward(self, x, features, squeeze_dim=False):
         sparse_emb, dense_values = [], []
         sparse_exists, dense_exists = False, False
         for fea in features:
             if isinstance(fea, SparseFeature):
-
+                if fea.shared_with is None:
+                    sparse_emb.append(self.embed_dict[fea.name](x[fea.name].long()).unsqueeze(1))
+                else:
+                    sparse_emb.append(self.embed_dict[fea.shared_with](x[fea.name].long()).unsqueeze(1))
             elif isinstance(fea, SequenceFeature):
                 if fea.pooling == "sum":
                     pooling_layer = SumPooling()
@@ -77,38 +83,75 @@ class EmbeddingLayer(nn.Module):
                     pooling_layer = ConcatPooling()
                 else:
                     raise ValueError("Sequence pooling method supports only pooling in %s, got %s." % (["sum", "mean"], fea.pooling))
-
-
+                fea_mask = InputMask()(x, fea)
+                if fea.shared_with is None:
+                    sparse_emb.append(pooling_layer(self.embed_dict[fea.name](x[fea.name].long()), fea_mask).unsqueeze(1))
                 else:
-                    sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long())).unsqueeze(1)) #shared specific sparse feature embedding
+                    sparse_emb.append(pooling_layer(self.embed_dict[fea.shared_with](x[fea.name].long()), fea_mask).unsqueeze(1)) # shared specific sparse feature embedding
             else:
-                dense_values.append(x[fea.name].float().unsqueeze(1))
+                dense_values.append(x[fea.name].float() if x[fea.name].float().dim() > 1 else x[fea.name].float().unsqueeze(1)) # .unsqueeze(1).unsqueeze(1)
 
         if len(dense_values) > 0:
             dense_exists = True
             dense_values = torch.cat(dense_values, dim=1)
         if len(sparse_emb) > 0:
             sparse_exists = True
-
+            # TODO: support concat dynamic embed_dim in dim 2
+            # [batch_size, num_features, embed_dim]
+            sparse_emb = torch.cat(sparse_emb, dim=1)
 
-        if squeeze_dim: #Note: if the emb_dim of sparse features is different, we must squeeze_dim
-            if dense_exists and not sparse_exists: #only input dense features
+        if squeeze_dim: # Note: if the emb_dim of sparse features is different, we must squeeze_dim
+            if dense_exists and not sparse_exists: # only input dense features
                 return dense_values
             elif not dense_exists and sparse_exists:
-
+                # squeeze dim to : [batch_size, num_features*embed_dim]
+                return sparse_emb.flatten(start_dim=1)
            elif dense_exists and sparse_exists:
-
+                # concat dense value with sparse embedding
+                return torch.cat((sparse_emb.flatten(start_dim=1), dense_values), dim=1)
             else:
                 raise ValueError("The input features can note be empty")
         else:
             if sparse_exists:
-                return sparse_emb #[batch_size, num_features, embed_dim]
+                return sparse_emb # [batch_size, num_features, embed_dim]
             else:
                 raise ValueError("If keep the original shape:[batch_size, num_features, embed_dim], expected %s in feature list, got %s" % ("SparseFeatures", features))
 
 
+class InputMask(nn.Module):
+    """Return inputs mask from given features
+
+    Shape:
+        - Input:
+            x (dict): {feature_name: feature_value}, sequence feature value is a 2D tensor with shape:`(batch_size, seq_len)`,\
+            sparse/dense feature value is a 1D tensor with shape `(batch_size)`.
+            features (list or SparseFeature or SequenceFeature): Note that the elements in features are either all instances of SparseFeature or all instances of SequenceFeature.
+        - Output:
+            - if input Sparse: `(batch_size, num_features)`
+            - if input Sequence: `(batch_size, num_features_seq, seq_length)`
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, features):
+        mask = []
+        if not isinstance(features, list):
+            features = [features]
+        for fea in features:
+            if isinstance(fea, SparseFeature) or isinstance(fea, SequenceFeature):
+                if fea.padding_idx is not None:
+                    fea_mask = x[fea.name].long() != fea.padding_idx
+                else:
+                    fea_mask = x[fea.name].long() != -1
+                mask.append(fea_mask.unsqueeze(1).float())
+            else:
+                raise ValueError("Only SparseFeature or SequenceFeature support to get mask.")
+        return torch.cat(mask, dim=1)
+
+
 class LR(nn.Module):
-    """Logistic Regression Module. It is the one Non-linear
+    """Logistic Regression Module. It is the one Non-linear
     transformation for input feature.
 
     Args:
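Note: the new InputMask module derives the padding mask directly from the raw feature values (entries equal to padding_idx become 0, everything else 1). A minimal tensor-level sketch with toy ids, assuming padding_idx = 0:

    import torch

    # toy id sequences; 0 is assumed to be the padding id here
    seq_values = torch.tensor([[3, 7, 0, 0],
                               [5, 0, 0, 0]])                 # (batch_size, seq_len)
    fea_mask = (seq_values.long() != 0).unsqueeze(1).float()  # (batch_size, 1, seq_len)
    # tensor([[[1., 1., 0., 0.]],
    #         [[1., 0., 0., 0.]]])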
@@ -134,7 +177,7 @@ class LR(nn.Module):
 
 class ConcatPooling(nn.Module):
     """Keep the origin sequence embedding shape
-
+
     Shape:
         - Input: `(batch_size, seq_length, embed_dim)`
         - Output: `(batch_size, seq_length, embed_dim)`
@@ -143,62 +186,73 @@ class ConcatPooling(nn.Module):
     def __init__(self):
         super().__init__()
 
-    def forward(self, x):
+    def forward(self, x, mask=None):
         return x
 
 
 class AveragePooling(nn.Module):
     """Pooling the sequence embedding matrix by `mean`.
-
+
     Shape:
-        - Input
+        - Input
+            x: `(batch_size, seq_length, embed_dim)`
+            mask: `(batch_size, 1, seq_length)`
         - Output: `(batch_size, embed_dim)`
     """
 
     def __init__(self):
         super().__init__()
 
-    def forward(self, x):
-
-
-
-
+    def forward(self, x, mask=None):
+        if mask is None:
+            return torch.mean(x, dim=1)
+        else:
+            sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
+            non_padding_length = mask.sum(dim=-1)
+            return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
 
 
 class SumPooling(nn.Module):
     """Pooling the sequence embedding matrix by `sum`.
 
     Shape:
-        - Input
+        - Input
+            x: `(batch_size, seq_length, embed_dim)`
+            mask: `(batch_size, 1, seq_length)`
         - Output: `(batch_size, embed_dim)`
     """
 
     def __init__(self):
         super().__init__()
 
-    def forward(self, x):
-
+    def forward(self, x, mask=None):
+        if mask is None:
+            return torch.sum(x, dim=1)
+        else:
+            return torch.bmm(mask, x).squeeze(1)
 
 
 class MLP(nn.Module):
-    """Multi Layer Perceptron Module, it is the most widely used module for
-    learning feature. Note we default add `BatchNorm1d` and `Activation`
+    """Multi Layer Perceptron Module, it is the most widely used module for
+    learning feature. Note we default add `BatchNorm1d` and `Activation`
     `Dropout` for each `Linear` Module.
 
     Args:
         input dim (int): input size of the first Linear Layer.
-
+        output_layer (bool): whether this MLP module is the output layer. If `True`, then append one Linear(*,1) module.
+        dims (list): output size of Linear Layer (default=[]).
         dropout (float): probability of an element to be zeroed (default = 0.5).
         activation (str): the activation function, support `[sigmoid, relu, prelu, dice, softmax]` (default='relu').
-        output_layer (bool): whether this MLP module is the output layer. If `True`, then append one Linear(*,1) module.
 
     Shape:
         - Input: `(batch_size, input_dim)`
         - Output: `(batch_size, 1)` or `(batch_size, dims[-1])`
     """
 
-    def __init__(self, input_dim, dims, dropout=0, activation="relu"
+    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
         super().__init__()
+        if dims is None:
+            dims = []
         layers = list()
         for i_dim in dims:
             layers.append(nn.Linear(input_dim, i_dim))
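Note: SumPooling and AveragePooling now accept the mask produced by InputMask and pool only over non-padding positions. A minimal sketch of the masked math, using the shapes from the docstrings above:

    import torch

    batch_size, seq_len, embed_dim = 2, 4, 3
    x = torch.randn(batch_size, seq_len, embed_dim)
    mask = torch.tensor([[[1., 1., 0., 0.]],
                         [[1., 1., 1., 0.]]])              # (batch_size, 1, seq_len)

    sum_pooled = torch.bmm(mask, x).squeeze(1)             # masked SumPooling: (batch_size, embed_dim)
    mean_pooled = sum_pooled / (mask.sum(dim=-1) + 1e-16)  # masked AveragePooling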
@@ -216,7 +270,7 @@ class MLP(nn.Module):
 
 class FM(nn.Module):
     """The Factorization Machine module, mentioned in the `DeepFM paper
-    <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
+    <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order
     feature interactions.
 
     Args:
@@ -237,4 +291,704 @@ class FM(nn.Module):
         ix = square_of_sum - sum_of_square
         if self.reduce_sum:
             ix = torch.sum(ix, dim=1, keepdim=True)
-        return 0.5 * ix
+        return 0.5 * ix
+
+
+class CIN(nn.Module):
+    """Compressed Interaction Network
+
+    Args:
+        input_dim (int): input dim of input tensor.
+        cin_size (list[int]): out channels of Conv1d.
+
+    Shape:
+        - Input: `(batch_size, num_features, embed_dim)`
+        - Output: `(batch_size, 1)`
+    """
+
+    def __init__(self, input_dim, cin_size, split_half=True):
+        super().__init__()
+        self.num_layers = len(cin_size)
+        self.split_half = split_half
+        self.conv_layers = torch.nn.ModuleList()
+        prev_dim, fc_input_dim = input_dim, 0
+        for i in range(self.num_layers):
+            cross_layer_size = cin_size[i]
+            self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
+            if self.split_half and i != self.num_layers - 1:
+                cross_layer_size //= 2
+            prev_dim = cross_layer_size
+            fc_input_dim += prev_dim
+        self.fc = torch.nn.Linear(fc_input_dim, 1)
+
+    def forward(self, x):
+        xs = list()
+        x0, h = x.unsqueeze(2), x
+        for i in range(self.num_layers):
+            x = x0 * h.unsqueeze(1)
+            batch_size, f0_dim, fin_dim, embed_dim = x.shape
+            x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
+            x = F.relu(self.conv_layers[i](x))
+            if self.split_half and i != self.num_layers - 1:
+                x, h = torch.split(x, x.shape[1] // 2, dim=1)
+            else:
+                h = x
+            xs.append(x)
+        return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
+
+
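Note: a minimal usage sketch for the new CIN block, matching the shapes documented above (the field count and layer sizes are illustrative values, not defaults from the package):

    import torch
    from torch_rechub.basic.layers import CIN

    cin = CIN(input_dim=26, cin_size=[128, 128])  # 26 feature fields
    x = torch.randn(32, 26, 16)                   # (batch_size, num_features, embed_dim)
    logit = cin(x)                                # (batch_size, 1)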
+class CrossLayer(nn.Module):
+    """
+    Cross layer.
+    Args:
+        input_dim (int): input dim of input tensor
+    """
+
+    def __init__(self, input_dim):
+        super(CrossLayer, self).__init__()
+        self.w = torch.nn.Linear(input_dim, 1, bias=False)
+        self.b = torch.nn.Parameter(torch.zeros(input_dim))
+
+    def forward(self, x_0, x_i):
+        x = self.w(x_i) * x_0 + self.b
+        return x
+
+
+class CrossNetwork(nn.Module):
+    """CrossNetwork mentioned in the DCN paper.
+
+    Args:
+        input_dim (int): input dim of input tensor
+
+    Shape:
+        - Input: `(batch_size, *)`
+        - Output: `(batch_size, *)`
+
+    """
+
+    def __init__(self, input_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
+        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
+
+    def forward(self, x):
+        """
+        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
+        """
+        x0 = x
+        for i in range(self.num_layers):
+            xw = self.w[i](x)
+            x = x0 * xw + self.b[i] + x
+        return x
+
+
+class CrossNetV2(nn.Module):
+
+    def __init__(self, input_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
+        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
+
+    def forward(self, x):
+        x0 = x
+        for i in range(self.num_layers):
+            x = x0 * self.w[i](x) + self.b[i] + x
+        return x
+
+
+class CrossNetMix(nn.Module):
+    """ CrossNetMix improves CrossNetwork by:
+    1. add MOE to learn feature interactions in different subspaces
+    2. add nonlinear transformations in low-dimensional space
+    :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
+    """
+
+    def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
+        super(CrossNetMix, self).__init__()
+        self.num_layers = num_layers
+        self.num_experts = num_experts
+
+        # U: (input_dim, low_rank)
+        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
+        # V: (input_dim, low_rank)
+        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
+        # C: (low_rank, low_rank)
+        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
+        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
+
+        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(torch.empty(input_dim, 1))) for i in range(self.num_layers)])
+
+    def forward(self, x):
+        x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
+        x_l = x_0
+        for i in range(self.num_layers):
+            output_of_experts = []
+            gating_score_experts = []
+            for expert_id in range(self.num_experts):
+                # (1) G(x_l)
+                # compute the gating score by x_l
+                gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))
+
+                # (2) E(x_l)
+                # project the input x_l to $\mathbb{R}^{r}$
+                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)
+
+                # nonlinear activation in low rank space
+                v_x = torch.tanh(v_x)
+                v_x = torch.matmul(self.c_list[i][expert_id], v_x)
+                v_x = torch.tanh(v_x)
+
+                # project back to $\mathbb{R}^{d}$
+                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, in_features, 1)
+
+                dot_ = uv_x + self.bias[i]
+                dot_ = x_0 * dot_  # Hadamard-product
+
+                output_of_experts.append(dot_.squeeze(2))
+
+
+            # (3) mixture of low-rank experts
+            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
+            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
+            moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
+            x_l = moe_out + x_l  # (bs, in_features, 1)
+
+        x_l = x_l.squeeze()  # (bs, in_features)
+        return x_l
+
+
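Note: the three cross variants above all take the flattened `(batch_size, num_fields * embed_dim)` input and preserve its shape, which is presumably how the new dcn.py / dcn_v2.py ranking models wire them. A minimal sketch with illustrative sizes:

    import torch
    from torch_rechub.basic.layers import CrossNetwork, CrossNetV2, CrossNetMix

    flat = torch.randn(32, 26 * 16)                           # (batch_size, num_fields * embed_dim)
    out_v1 = CrossNetwork(flat.shape[1], num_layers=3)(flat)  # x_{l+1} = x_0 * (w_l^T x_l) + b_l + x_l
    out_v2 = CrossNetV2(flat.shape[1], num_layers=3)(flat)    # full-rank weight W_l per layer
    out_mix = CrossNetMix(flat.shape[1], num_layers=2)(flat)  # low-rank mixture-of-experts variant
    # all three outputs keep the input shape (32, 26 * 16)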
+class SENETLayer(nn.Module):
+    """
+    A weighted feature gating system in the SENet paper
+    Args:
+        num_fields (int): number of feature fields
+
+    Shape:
+        - num_fields: `(batch_size, *)`
+        - Output: `(batch_size, *)`
+    """
+
+    def __init__(self, num_fields, reduction_ratio=3):
+        super(SENETLayer, self).__init__()
+        reduced_size = max(1, int(num_fields / reduction_ratio))
+        self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False), nn.ReLU(), nn.Linear(reduced_size, num_fields, bias=False), nn.ReLU())
+
+    def forward(self, x):
+        z = torch.mean(x, dim=-1, out=None)
+        a = self.mlp(z)
+        v = x * a.unsqueeze(-1)
+        return v
+
+
+class BiLinearInteractionLayer(nn.Module):
+    """
+    Bilinear feature interaction module, which is an improved model of the FFM model
+    Args:
+        num_fields (int): number of feature fields
+        bilinear_type(str): the type bilinear interaction function
+    Shape:
+        - num_fields: `(batch_size, *)`
+        - Output: `(batch_size, *)`
+    """
+
+    def __init__(self, input_dim, num_fields, bilinear_type="field_interaction"):
+        super(BiLinearInteractionLayer, self).__init__()
+        self.bilinear_type = bilinear_type
+        if self.bilinear_type == "field_all":
+            self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
+        elif self.bilinear_type == "field_each":
+            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)])
+        elif self.bilinear_type == "field_interaction":
+            self.bilinear_layer = nn.ModuleList([nn.Linear(input_dim, input_dim, bias=False) for i, j in combinations(range(num_fields), 2)])
+        else:
+            raise NotImplementedError()
+
+    def forward(self, x):
+        feature_emb = torch.split(x, 1, dim=1)
+        if self.bilinear_type == "field_all":
+            bilinear_list = [self.bilinear_layer(v_i) * v_j for v_i, v_j in combinations(feature_emb, 2)]
+        elif self.bilinear_type == "field_each":
+            bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]
+        elif self.bilinear_type == "field_interaction":
+            bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]
+        return torch.cat(bilinear_list, dim=1)
+
+
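Note: SENETLayer re-weights whole fields and BiLinearInteractionLayer crosses them pairwise; these are the building blocks the new models/ranking/fibinet.py presumably composes. A minimal shape sketch:

    import torch
    from torch_rechub.basic.layers import SENETLayer, BiLinearInteractionLayer

    emb = torch.randn(32, 10, 16)                          # (batch_size, num_fields, embed_dim)
    senet_emb = SENETLayer(num_fields=10)(emb)             # same shape, field-wise re-weighted
    crossed = BiLinearInteractionLayer(16, 10)(senet_emb)  # (batch_size, 10*9/2, embed_dim) = (32, 45, 16)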
+class MultiInterestSA(nn.Module):
+    """MultiInterest Attention mentioned in the Comirec paper.
+
+    Args:
+        embedding_dim (int): embedding dim of item embedding
+        interest_num (int): num of interest
+        hidden_dim (int): hidden dim
+
+    Shape:
+        - Input: seq_emb : (batch,seq,emb)
+                 mask : (batch,seq,1)
+        - Output: `(batch_size, interest_num, embedding_dim)`
+
+    """
+
+    def __init__(self, embedding_dim, interest_num, hidden_dim=None):
+        super(MultiInterestSA, self).__init__()
+        self.embedding_dim = embedding_dim
+        self.interest_num = interest_num
+        if hidden_dim is None:
+            self.hidden_dim = self.embedding_dim * 4
+        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
+        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
+        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
+
+    def forward(self, seq_emb, mask=None):
+        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
+        if mask is not None:
+            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + - \
+                1.e9 * (1 - mask.float())
+            A = F.softmax(A, dim=1)
+        else:
+            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
+        A = A.permute(0, 2, 1)
+        multi_interest_emb = torch.matmul(A, seq_emb)
+        return multi_interest_emb
+
+
+class CapsuleNetwork(nn.Module):
+    """CapsuleNetwork mentioned in the Comirec and MIND paper.
+
+    Args:
+        hidden_size (int): embedding dim of item embedding
+        seq_len (int): length of the item sequence
+        bilinear_type (int): 0 for MIND, 2 for ComirecDR
+        interest_num (int): num of interest
+        routing_times (int): routing times
+
+    Shape:
+        - Input: seq_emb : (batch,seq,emb)
+                 mask : (batch,seq,1)
+        - Output: `(batch_size, interest_num, embedding_dim)`
+
+    """
+
+    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
+        super(CapsuleNetwork, self).__init__()
+        self.embedding_dim = embedding_dim  # h
+        self.seq_len = seq_len  # s
+        self.bilinear_type = bilinear_type
+        self.interest_num = interest_num
+        self.routing_times = routing_times
+
+        self.relu_layer = relu_layer
+        self.stop_grad = True
+        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
+        if self.bilinear_type == 0:  # MIND
+            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
+        elif self.bilinear_type == 1:
+            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
+        else:
+            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
+
+    def forward(self, item_eb, mask):
+        if self.bilinear_type == 0:
+            item_eb_hat = self.linear(item_eb)
+            item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
+        elif self.bilinear_type == 1:
+            item_eb_hat = self.linear(item_eb)
+        else:
+            u = torch.unsqueeze(item_eb, dim=2)
+            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+
+        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
+        item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
+        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
+
+        if self.stop_grad:
+            item_eb_hat_iter = item_eb_hat.detach()
+        else:
+            item_eb_hat_iter = item_eb_hat
+
+        if self.bilinear_type > 0:
+            capsule_weight = torch.zeros(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
+        else:
+            capsule_weight = torch.randn(item_eb_hat.shape[0], self.interest_num, self.seq_len, device=item_eb.device, requires_grad=False)
+
+        for i in range(self.routing_times):  # dynamic routing, repeated 3 times
+            atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
+            paddings = torch.zeros_like(atten_mask, dtype=torch.float)
+
+            capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
+            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
+            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
+
+            if i < 2:
+                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
+                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
+                scalar_factor = cap_norm / \
+                    (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
+                interest_capsule = scalar_factor * interest_capsule
+
+                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
+                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
+                capsule_weight = capsule_weight + delta_weight
+            else:
+                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
+                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
+                scalar_factor = cap_norm / \
+                    (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
+                interest_capsule = scalar_factor * interest_capsule
+
+        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
+
+        if self.relu_layer:
+            interest_capsule = self.relu(interest_capsule)
+
+        return interest_capsule
+
+
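Note: both multi-interest extractors above map a padded item sequence to a fixed number of interest vectors; MultiInterestSA masks with a `(batch, seq, 1)` tensor, while the capsule routing as written appears to expect a 2-D `(batch, seq)` mask. A minimal shape sketch:

    import torch
    from torch_rechub.basic.layers import MultiInterestSA, CapsuleNetwork

    seq_emb = torch.randn(32, 50, 64)             # (batch, seq, emb)
    mask = (torch.rand(32, 50, 1) > 0.2).float()  # 1.0 = real item, 0.0 = padding
    sa_out = MultiInterestSA(embedding_dim=64, interest_num=4)(seq_emb, mask)
    caps_out = CapsuleNetwork(embedding_dim=64, seq_len=50, bilinear_type=2, interest_num=4)(seq_emb, mask.squeeze(-1))
    # both outputs: (32, 4, 64)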
+class FFM(nn.Module):
+    """The Field-aware Factorization Machine module, mentioned in the `FFM paper
+    <https://dl.acm.org/doi/abs/10.1145/2959100.2959134>`. It explicitly models
+    multi-channel second-order feature interactions, with each feature filed
+    corresponding to one channel.
+
+    Args:
+        num_fields (int): number of feature fields.
+        reduce_sum (bool): whether to sum in embed_dim (default = `True`).
+
+    Shape:
+        - Input: `(batch_size, num_fields, num_fields, embed_dim)`
+        - Output: `(batch_size, num_fields*(num_fields-1)/2, 1)` or `(batch_size, num_fields*(num_fields-1)/2, embed_dim)`
+    """
+
+    def __init__(self, num_fields, reduce_sum=True):
+        super().__init__()
+        self.num_fields = num_fields
+        self.reduce_sum = reduce_sum
+
+    def forward(self, x):
+        # compute (non-redundant) second order field-aware feature crossings
+        crossed_embeddings = []
+        for i in range(self.num_fields - 1):
+            for j in range(i + 1, self.num_fields):
+                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
+        crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
+
+        # if reduce_sum is true, the crossing operation is effectively inner
+        # product, other wise Hadamard-product
+        if self.reduce_sum:
+            crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
+        return crossed_embeddings
+
+
+class CEN(nn.Module):
+    """The Compose-Excitation Network module, mentioned in the `FAT-DeepFFM paper
+    <https://arxiv.org/abs/1905.06336>`, a modified version of
+    `Squeeze-and-Excitation Network” (SENet) (Hu et al., 2017)`. It is used to
+    highlight the importance of second-order feature crosses.
+
+    Args:
+        embed_dim (int): the dimensionality of categorical value embedding.
+        num_field_crosses (int): the number of second order crosses between feature fields.
+        reduction_ratio (int): the between the dimensions of input layer and hidden layer of the MLP module.
+
+    Shape:
+        - Input: `(batch_size, num_fields, num_fields, embed_dim)`
+        - Output: `(batch_size, num_fields*(num_fields-1)/2 * embed_dim)`
+    """
+
+    def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
+        super().__init__()
+
+        # convolution weight (Eq.7 FAT-DeepFFM)
+        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
+
+        # two FC layers that computes the field attention
+        self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses // reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
+
+    def forward(self, em):
+        # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape
+        # [batch_size, num_field_crosses]
+        d = F.relu((self.u.squeeze(0) * em).sum(-1))
+
+        # compute field attention (Eq.9), output shape [batch_size,
+        # num_field_crosses]
+        s = self.mlp_att(d)
+
+        # rescale original embedding with field attention (Eq.10), output shape
+        # [batch_size, num_field_crosses, embed_dim]
+        aem = s.unsqueeze(-1) * em
+        return aem.flatten(start_dim=1)
+
+
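Note: FFM consumes field-aware embeddings and, with reduce_sum=False, its Hadamard crosses can be fed into CEN for re-weighting, the FAT-DeepFFM combination the new deepffm.py model presumably builds on. The 3-D CEN input below follows its forward code rather than the 4-D shape in its docstring; sizes are illustrative:

    import torch
    from torch_rechub.basic.layers import FFM, CEN

    num_fields, embed_dim = 6, 8
    ffm_emb = torch.randn(32, num_fields, num_fields, embed_dim)  # field-aware embeddings
    crosses = FFM(num_fields, reduce_sum=False)(ffm_emb)          # (32, 6*5/2, 8) = (32, 15, 8)
    attended = CEN(embed_dim=8, num_field_crosses=15, reduction_ratio=3)(crosses)  # (32, 15 * 8)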
+# ============ HSTU Layers (new) ============
+
+
+class HSTULayer(nn.Module):
+    """Single HSTU layer.
+
+    This layer implements the core HSTU "sequential transduction unit": a
+    multi-head self-attention block with gating and a position-wise FFN, plus
+    residual connections and LayerNorm.
+
+    Args:
+        d_model (int): Hidden dimension of the model. Default: 512.
+        n_heads (int): Number of attention heads. Default: 8.
+        dqk (int): Dimension of query/key per head. Default: 64.
+        dv (int): Dimension of value per head. Default: 64.
+        dropout (float): Dropout rate applied in the layer. Default: 0.1.
+        use_rel_pos_bias (bool): Whether to use relative position bias.
+
+    Shape:
+        - Input: ``(batch_size, seq_len, d_model)``
+        - Output: ``(batch_size, seq_len, d_model)``
+
+    Example:
+        >>> layer = HSTULayer(d_model=512, n_heads=8)
+        >>> x = torch.randn(32, 256, 512)
+        >>> output = layer(x)
+        >>> output.shape
+        torch.Size([32, 256, 512])
+    """
+
+    def __init__(self, d_model=512, n_heads=8, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
+        super().__init__()
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.dqk = dqk
+        self.dv = dv
+        self.dropout_rate = dropout
+        self.use_rel_pos_bias = use_rel_pos_bias
+
+        # Validate dimensions
+        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+
+        # Projection 1: d_model -> 2*n_heads*dqk + 2*n_heads*dv
+        proj1_out_dim = 2 * n_heads * dqk + 2 * n_heads * dv
+        self.proj1 = nn.Linear(d_model, proj1_out_dim)
+
+        # Projection 2: n_heads*dv -> d_model
+        self.proj2 = nn.Linear(n_heads * dv, d_model)
+
+        # Feed-forward network (FFN)
+        # Standard Transformer uses 4*d_model as the hidden dimension of FFN
+        ffn_hidden_dim = 4 * d_model
+        self.ffn = nn.Sequential(nn.Linear(d_model, ffn_hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ffn_hidden_dim, d_model), nn.Dropout(dropout))
+
+        # Layer normalization
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+        # Dropout
+        self.dropout = nn.Dropout(dropout)
+
+        # Scaling factor for attention scores
+        self.scale = 1.0 / (dqk**0.5)
+
+    def forward(self, x, rel_pos_bias=None):
+        """Forward pass of a single HSTU layer.
+
+        Args:
+            x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
+            rel_pos_bias (Tensor, optional): Relative position bias of shape
+                ``(1, n_heads, seq_len, seq_len)``.
+
+        Returns:
+            Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
+        """
+        batch_size, seq_len, _ = x.shape
+
+        # Residual connection
+        residual = x
+
+        # Layer normalization
+        x = self.norm1(x)
+
+        # Projection 1: (B, L, D) -> (B, L, 2*H*dqk + 2*H*dv)
+        proj_out = self.proj1(x)
+
+        # Split into Q, K, U, V
+        # Q, K: (B, L, H, dqk)
+        # U, V: (B, L, H, dv)
+        q = proj_out[..., :self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
+        k = proj_out[..., self.n_heads * self.dqk:2 * self.n_heads * self.dqk].reshape(batch_size, seq_len, self.n_heads, self.dqk)
+        u = proj_out[..., 2 * self.n_heads * self.dqk:2 * self.n_heads * self.dqk + self.n_heads * self.dv].reshape(batch_size, seq_len, self.n_heads, self.dv)
+        v = proj_out[..., 2 * self.n_heads * self.dqk + self.n_heads * self.dv:].reshape(batch_size, seq_len, self.n_heads, self.dv)
+
+        # Transpose to (B, H, L, dqk/dv)
+        q = q.transpose(1, 2)  # (B, H, L, dqk)
+        k = k.transpose(1, 2)  # (B, H, L, dqk)
+        u = u.transpose(1, 2)  # (B, H, L, dv)
+        v = v.transpose(1, 2)  # (B, H, L, dv)
+
+        # Compute attention scores: (B, H, L, L)
+        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+
+        # Add causal mask (prevent attending to future positions)
+        # For generative models this is required so that position i only attends
+        # to positions <= i.
+        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
+        scores = scores.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+        # Add relative position bias if provided
+        if rel_pos_bias is not None:
+            scores = scores + rel_pos_bias
+
+        # Softmax over attention scores
+        attn_weights = F.softmax(scores, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        # Attention output: (B, H, L, dv)
+        attn_output = torch.matmul(attn_weights, v)
+
+        # Gating mechanism: apply a learned gate on top of attention output
+        # First transpose back to (B, L, H, dv)
+        attn_output = attn_output.transpose(1, 2)  # (B, L, H, dv)
+        u = u.transpose(1, 2)  # (B, L, H, dv)
+
+        # Apply element-wise gate: (B, L, H, dv)
+        gated_output = attn_output * torch.sigmoid(u)
+
+        # Merge heads: (B, L, H*dv)
+        gated_output = gated_output.reshape(batch_size, seq_len, self.n_heads * self.dv)
+
+        # Projection 2: (B, L, H*dv) -> (B, L, D)
+        output = self.proj2(gated_output)
+        output = self.dropout(output)
+
+        # Residual connection
+        output = output + residual
+
+        # Second residual block: LayerNorm + FFN + residual connection
+        residual = output
+        output = self.norm2(output)
+        output = self.ffn(output)
+        output = output + residual
+
+        return output
+
+
+class HSTUBlock(nn.Module):
+    """Stacked HSTU block.
+
+    This block stacks multiple :class:`HSTULayer` layers to form a deep HSTU
+    encoder for sequential recommendation.
+
+    Args:
+        d_model (int): Hidden dimension of the model. Default: 512.
+        n_heads (int): Number of attention heads. Default: 8.
+        n_layers (int): Number of stacked HSTU layers. Default: 4.
+        dqk (int): Dimension of query/key per head. Default: 64.
+        dv (int): Dimension of value per head. Default: 64.
+        dropout (float): Dropout rate applied in each layer. Default: 0.1.
+        use_rel_pos_bias (bool): Whether to use relative position bias.
+
+    Shape:
+        - Input: ``(batch_size, seq_len, d_model)``
+        - Output: ``(batch_size, seq_len, d_model)``
+
+    Example:
+        >>> block = HSTUBlock(d_model=512, n_heads=8, n_layers=4)
+        >>> x = torch.randn(32, 256, 512)
+        >>> output = block(x)
+        >>> output.shape
+        torch.Size([32, 256, 512])
+    """
+
+    def __init__(self, d_model=512, n_heads=8, n_layers=4, dqk=64, dv=64, dropout=0.1, use_rel_pos_bias=True):
+        super().__init__()
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+
+        # Create a stack of HSTULayer modules
+        self.layers = nn.ModuleList([HSTULayer(d_model=d_model, n_heads=n_heads, dqk=dqk, dv=dv, dropout=dropout, use_rel_pos_bias=use_rel_pos_bias) for _ in range(n_layers)])
+
+    def forward(self, x, rel_pos_bias=None):
+        """Forward pass through all stacked HSTULayer modules.
+
+        Args:
+            x (Tensor): Input tensor of shape ``(batch_size, seq_len, d_model)``.
+            rel_pos_bias (Tensor, optional): Relative position bias shared across
+                all layers.
+
+        Returns:
+            Tensor: Output tensor of shape ``(batch_size, seq_len, d_model)``.
+        """
+        for layer in self.layers:
+            x = layer(x, rel_pos_bias=rel_pos_bias)
+        return x
+
+
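Note: both HSTU classes accept an optional relative position bias that is added to the causally masked attention scores; its documented shape is `(1, n_heads, seq_len, seq_len)`. A minimal sketch with smaller, illustrative sizes:

    import torch
    from torch_rechub.basic.layers import HSTUBlock

    block = HSTUBlock(d_model=128, n_heads=4, n_layers=2, dqk=32, dv=32)
    x = torch.randn(8, 64, 128)                # (batch_size, seq_len, d_model)
    rel_pos_bias = torch.zeros(1, 4, 64, 64)   # (1, n_heads, seq_len, seq_len), shared by all layers
    out = block(x, rel_pos_bias=rel_pos_bias)  # (8, 64, 128)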
+class InteractingLayer(nn.Module):
+    """Multi-head Self-Attention based Interacting Layer, used in AutoInt model.
+
+    Args:
+        embed_dim (int): the embedding dimension.
+        num_heads (int): the number of attention heads (default=2).
+        dropout (float): the dropout rate (default=0.0).
+        residual (bool): whether to use residual connection (default=True).
+
+    Shape:
+        - Input: `(batch_size, num_fields, embed_dim)`
+        - Output: `(batch_size, num_fields, embed_dim)`
+    """
+
+    def __init__(self, embed_dim, num_heads=2, dropout=0.0, residual=True):
+        super().__init__()
+        if embed_dim % num_heads != 0:
+            raise ValueError("embed_dim must be divisible by num_heads")
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.residual = residual
+
+        self.W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.W_K = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.W_V = nn.Linear(embed_dim, embed_dim, bias=False)
+
+        # Residual connection
+        self.W_Res = nn.Linear(embed_dim, embed_dim, bias=False) if residual else None
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+    def forward(self, x):
+        """
+        Args:
+            x: input tensor with shape (batch_size, num_fields, embed_dim)
+        """
+        batch_size, num_fields, embed_dim = x.shape
+
+        # Linear projections
+        Q = self.W_Q(x)  # (batch_size, num_fields, embed_dim)
+        K = self.W_K(x)  # (batch_size, num_fields, embed_dim)
+        V = self.W_V(x)  # (batch_size, num_fields, embed_dim)
+
+        # Reshape for multi-head attention
+        # (batch_size, num_heads, num_fields, head_dim)
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Scaled dot-product attention
+        # (batch_size, num_heads, num_fields, num_fields)
+        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
+        attn_weights = F.softmax(attn_scores, dim=-1)
+
+        if self.dropout is not None:
+            attn_weights = self.dropout(attn_weights)
+
+        # Apply attention to values
+        # (batch_size, num_heads, num_fields, head_dim)
+        attn_output = torch.matmul(attn_weights, V)
+
+        # Concatenate heads
+        # (batch_size, num_fields, embed_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, num_fields, embed_dim)
+
+        # Residual connection
+        if self.residual and self.W_Res is not None:
+            attn_output = attn_output + self.W_Res(x)
+
+        return F.relu(attn_output)
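Note: a minimal usage sketch for the new InteractingLayer, matching its documented shapes (the AutoInt model stacks several of these):

    import torch
    from torch_rechub.basic.layers import InteractingLayer

    layer = InteractingLayer(embed_dim=16, num_heads=2)
    x = torch.randn(32, 10, 16)  # (batch_size, num_fields, embed_dim)
    out = layer(x)               # (32, 10, 16)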