nextrec 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +41 -0
- nextrec/__version__.py +1 -0
- nextrec/basic/__init__.py +0 -0
- nextrec/basic/activation.py +92 -0
- nextrec/basic/callback.py +35 -0
- nextrec/basic/dataloader.py +447 -0
- nextrec/basic/features.py +87 -0
- nextrec/basic/layers.py +985 -0
- nextrec/basic/loggers.py +124 -0
- nextrec/basic/metrics.py +557 -0
- nextrec/basic/model.py +1438 -0
- nextrec/data/__init__.py +27 -0
- nextrec/data/data_utils.py +132 -0
- nextrec/data/preprocessor.py +662 -0
- nextrec/loss/__init__.py +35 -0
- nextrec/loss/loss_utils.py +136 -0
- nextrec/loss/match_losses.py +294 -0
- nextrec/models/generative/hstu.py +0 -0
- nextrec/models/generative/tiger.py +0 -0
- nextrec/models/match/__init__.py +13 -0
- nextrec/models/match/dssm.py +200 -0
- nextrec/models/match/dssm_v2.py +162 -0
- nextrec/models/match/mind.py +210 -0
- nextrec/models/match/sdm.py +253 -0
- nextrec/models/match/youtube_dnn.py +172 -0
- nextrec/models/multi_task/esmm.py +129 -0
- nextrec/models/multi_task/mmoe.py +161 -0
- nextrec/models/multi_task/ple.py +260 -0
- nextrec/models/multi_task/share_bottom.py +126 -0
- nextrec/models/ranking/__init__.py +17 -0
- nextrec/models/ranking/afm.py +118 -0
- nextrec/models/ranking/autoint.py +140 -0
- nextrec/models/ranking/dcn.py +120 -0
- nextrec/models/ranking/deepfm.py +95 -0
- nextrec/models/ranking/dien.py +214 -0
- nextrec/models/ranking/din.py +181 -0
- nextrec/models/ranking/fibinet.py +130 -0
- nextrec/models/ranking/fm.py +87 -0
- nextrec/models/ranking/masknet.py +125 -0
- nextrec/models/ranking/pnn.py +128 -0
- nextrec/models/ranking/widedeep.py +105 -0
- nextrec/models/ranking/xdeepfm.py +117 -0
- nextrec/utils/__init__.py +18 -0
- nextrec/utils/common.py +14 -0
- nextrec/utils/embedding.py +19 -0
- nextrec/utils/initializer.py +47 -0
- nextrec/utils/optimizer.py +75 -0
- nextrec-0.1.1.dist-info/METADATA +302 -0
- nextrec-0.1.1.dist-info/RECORD +51 -0
- nextrec-0.1.1.dist-info/WHEEL +4 -0
- nextrec-0.1.1.dist-info/licenses/LICENSE +21 -0
nextrec/basic/layers.py
ADDED
|
@@ -0,0 +1,985 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Layer implementations used across NextRec models.
|
|
3
|
+
|
|
4
|
+
Date: create on 27/10/2025, update on 19/11/2025
|
|
5
|
+
Author:
|
|
6
|
+
Yang Zhou,zyaztec@gmail.com
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from itertools import combinations
|
|
12
|
+
from typing import Iterable, Sequence, Union
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
import torch.nn.functional as F
|
|
17
|
+
|
|
18
|
+
from nextrec.basic.activation import activation_layer
|
|
19
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
20
|
+
from nextrec.utils.initializer import get_initializer_fn
|
|
21
|
+
|
|
22
|
+
# Any feature specification accepted by EmbeddingLayer.
Feature = Union[DenseFeature, SparseFeature, SequenceFeature]

# Public names exported by ``from nextrec.basic.layers import *``.
__all__ = [
    # output / embedding plumbing
    "PredictionLayer", "EmbeddingLayer", "InputMask", "LR",
    # sequence pooling
    "ConcatPooling", "AveragePooling", "SumPooling",
    # feed-forward and feature-interaction layers
    "MLP", "FM", "FFM", "CEN", "CIN",
    "CrossLayer", "CrossNetwork", "CrossNetV2", "CrossNetMix",
    "SENETLayer", "BiLinearInteractionLayer",
    # interest extraction / attention / recurrent layers
    "MultiInterestSA", "CapsuleNetwork", "MultiHeadSelfAttention",
    "AttentionPoolingLayer", "DynamicGRU", "AUGRU",
]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class PredictionLayer(nn.Module):
    """Turn raw model scores into per-task predictions.

    The last input dimension is cut into consecutive slices, one per task
    (widths from ``task_dims``). Each slice gets the optional shared bias and
    then its task activation:

    * classification-like tasks -> ``torch.sigmoid``
    * regression-like tasks     -> identity
    * multiclass tasks          -> softmax over the slice

    Args:
        task_type: One task name or a sequence of task names.
        task_dims: ``None`` (width 1 per task), one int shared by all tasks,
            or one int per task.
        use_bias: Learn an additive bias over all output columns.
        return_logits: Skip the activations and return the biased logits.

    Raises:
        ValueError: On an empty ``task_type``, a ``task_dims`` length that is
            neither 1 nor the number of tasks, or a task width < 1. At forward
            time, on an input width other than ``sum(task_dims)`` or an
            unknown task name.
    """

    _CLASSIFICATION_TASKS = {"classification", "binary", "ctr", "ranking", "match", "matching"}
    _REGRESSION_TASKS = {"regression", "continuous"}
    _MULTICLASS_TASKS = {"multiclass", "softmax"}

    def __init__(
        self,
        task_type: Union[str, Sequence[str]] = "binary",
        task_dims: Union[int, Sequence[int], None] = None,
        use_bias: bool = True,
        return_logits: bool = False,
    ):
        super().__init__()

        self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
        num_tasks = len(self.task_types)
        if not num_tasks:
            raise ValueError("At least one task_type must be specified.")

        if task_dims is None:
            dims = [1] * num_tasks
        elif isinstance(task_dims, int):
            dims = [task_dims]
        else:
            dims = list(task_dims)

        if len(dims) != 1 and len(dims) != num_tasks:
            raise ValueError(
                "task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
            )
        if len(dims) == 1 and num_tasks > 1:
            # One shared width is broadcast to every task.
            dims = dims * num_tasks

        self.task_dims = dims
        self.total_dim = sum(self.task_dims)
        self.return_logits = return_logits

        # Half-open [start, end) column range for each task.
        self._task_slices: list[tuple[int, int]] = []
        offset = 0
        for width in self.task_dims:
            if width < 1:
                raise ValueError("Each task dimension must be >= 1.")
            self._task_slices.append((offset, offset + width))
            offset += width

        if not use_bias:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(torch.zeros(self.total_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply bias and per-task activations.

        A 1-D input is treated as ``[N, 1]``. The result is squeezed back to
        1-D when the total output width is 1.
        """
        scores = x.unsqueeze(-1) if x.dim() == 1 else x
        width = scores.shape[-1]
        if width != self.total_dim:
            raise ValueError(
                f"Input last dimension ({width}) does not match expected total dimension ({self.total_dim})."
            )
        if self.bias is not None:
            scores = scores + self.bias

        pieces: list[torch.Tensor] = []
        for name, (lo, hi) in zip(self.task_types, self._task_slices):
            chunk = scores[..., lo:hi]
            pieces.append(chunk if self.return_logits else self._get_activation(name)(chunk))

        out = torch.cat(pieces, dim=-1)
        return out.squeeze(-1) if out.shape[-1] == 1 else out

    def _get_activation(self, task_type: str):
        """Return the output activation for ``task_type`` (case-insensitive)."""
        key = task_type.lower()
        if key in self._CLASSIFICATION_TASKS:
            return torch.sigmoid
        if key in self._REGRESSION_TASKS:
            return lambda t: t
        if key in self._MULTICLASS_TASKS:
            return lambda t: torch.softmax(t, dim=-1)
        raise ValueError(f"Unsupported task_type '{task_type}'.")
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class EmbeddingLayer(nn.Module):
    """Look up or project every input feature into embedding space.

    * ``SparseFeature`` / ``SequenceFeature`` get an ``nn.Embedding`` keyed by
      ``feature.embedding_name``. Features with the same ``embedding_name``
      share one table, created by the first such feature seen.
    * ``DenseFeature`` gets an ``nn.Linear(input_dim, embedding_dim or input_dim)``
      keyed by ``feature.name``.

    Args:
        features: Every feature this layer must be able to embed.

    Raises:
        TypeError: If a feature is not one of the three supported types.
    """

    def __init__(self, features: Sequence[Feature]):
        super().__init__()
        self.features = list(features)
        self.embed_dict = nn.ModuleDict()
        self.dense_transforms = nn.ModuleDict()
        self.dense_input_dims: dict[str, int] = {}

        for feature in self.features:
            if isinstance(feature, (SparseFeature, SequenceFeature)):
                if feature.embedding_name in self.embed_dict:
                    # Table already created by an earlier feature sharing this name.
                    continue

                embedding = nn.Embedding(
                    num_embeddings=feature.vocab_size,
                    embedding_dim=feature.embedding_dim,
                    padding_idx=feature.padding_idx,
                )
                embedding.weight.requires_grad = feature.trainable

                initialization = get_initializer_fn(
                    init_type=feature.init_type,
                    activation="linear",
                    param=feature.init_params,
                )
                initialization(embedding.weight)
                self.embed_dict[feature.embedding_name] = embedding

            elif isinstance(feature, DenseFeature):
                if feature.name in self.dense_transforms:
                    continue
                # An absent input_dim defaults to 1. An absent/None/0 embedding_dim
                # falls back to in_dim. Both are clamped to >= 1.
                in_dim = max(int(getattr(feature, "input_dim", 1)), 1)
                out_dim = max(int(getattr(feature, "embedding_dim", None) or in_dim), 1)
                dense_linear = nn.Linear(in_dim, out_dim, bias=True)
                nn.init.xavier_uniform_(dense_linear.weight)
                nn.init.zeros_(dense_linear.bias)
                self.dense_transforms[feature.name] = dense_linear
                self.dense_input_dims[feature.name] = in_dim

            else:
                raise TypeError(f"Unsupported feature type: {type(feature)}")
        self.output_dim = self._compute_output_dim()

    def forward(
        self,
        x: dict[str, torch.Tensor],
        features: Sequence[Feature],
        squeeze_dim: bool = False,
    ) -> torch.Tensor:
        """Embed ``features`` from the input batch ``x``.

        Args:
            x: Mapping from feature name to its input tensor.
            features: Features to embed. Sparse/sequence outputs come first
                (in the given order), followed by dense outputs.
            squeeze_dim: If True, flatten every embedding to 2-D and
                concatenate along dim 1. If False, stack per-feature
                embeddings along dim 1. This requires a common last dimension:
                dense features whose projected width differs from it are
                skipped, and a warning is emitted.

        Returns:
            With ``squeeze_dim=True``: ``[batch, total_width]``. Otherwise,
            presumably ``[batch, num_fields, embed_dim]`` for 1-D id inputs
            (confirm against the callers' input shapes).

        Raises:
            ValueError: On an unknown sequence combiner, or when nothing can
                be returned.
        """
        sparse_embeds: list[torch.Tensor] = []
        dense_embeds: list[torch.Tensor] = []

        for feature in features:
            if isinstance(feature, SparseFeature):
                embed = self.embed_dict[feature.embedding_name]
                sparse_embeds.append(embed(x[feature.name].long()).unsqueeze(1))

            elif isinstance(feature, SequenceFeature):
                seq_input = x[feature.name].long()
                if feature.max_len is not None and seq_input.size(1) > feature.max_len:
                    # Keep only the last max_len positions.
                    seq_input = seq_input[:, -feature.max_len :]

                embed = self.embed_dict[feature.embedding_name]
                seq_emb = embed(seq_input)  # [B, seq_len, emb_dim]

                if feature.combiner == "mean":
                    pooling_layer = AveragePooling()
                elif feature.combiner == "sum":
                    pooling_layer = SumPooling()
                elif feature.combiner == "concat":
                    pooling_layer = ConcatPooling()
                else:
                    raise ValueError(f"Unknown combiner for {feature.name}: {feature.combiner}")

                feature_mask = InputMask()(x, feature, seq_input)
                sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))

            elif isinstance(feature, DenseFeature):
                dense_embeds.append(self._project_dense(feature, x))

        if squeeze_dim:
            flattened_sparse = [emb.flatten(start_dim=1) for emb in sparse_embeds]
            pieces = []
            if flattened_sparse:
                pieces.append(torch.cat(flattened_sparse, dim=1))
            if dense_embeds:
                pieces.append(torch.cat(dense_embeds, dim=1))

            if not pieces:
                raise ValueError("No input features found for EmbeddingLayer.")

            return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)

        # squeeze_dim=False requires embeddings with identical last dimension
        output_embeddings = list(sparse_embeds)
        if dense_embeds:
            target_dim = None
            if output_embeddings:
                target_dim = output_embeddings[0].shape[-1]
            elif len({emb.shape[-1] for emb in dense_embeds}) == 1:
                target_dim = dense_embeds[0].shape[-1]

            if target_dim is not None:
                aligned_dense = [
                    emb.unsqueeze(1) for emb in dense_embeds if emb.shape[-1] == target_dim
                ]
                num_dropped = len(dense_embeds) - len(aligned_dense)
                if num_dropped:
                    # Dropping inputs silently hides misconfigured embedding_dim values.
                    import warnings

                    warnings.warn(
                        f"EmbeddingLayer: squeeze_dim=False dropped {num_dropped} dense feature(s) "
                        f"whose projected dimension differs from the stacking dimension {target_dim}.",
                        stacklevel=2,
                    )
                output_embeddings.extend(aligned_dense)

        if not output_embeddings:
            raise ValueError(
                "squeeze_dim=False requires at least one sparse/sequence feature or "
                "dense features with identical projected dimensions."
            )

        return torch.cat(output_embeddings, dim=1)

    def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """Flatten a dense input to ``[batch, input_dim]`` and apply its linear projection.

        Raises:
            KeyError: If the feature is missing from ``x``.
            ValueError: If the flattened width differs from the configured ``input_dim``.
        """
        if feature.name not in x:
            raise KeyError(f"Dense feature '{feature.name}' is missing from input.")

        value = x[feature.name].float()
        if value.dim() == 1:
            value = value.unsqueeze(-1)
        else:
            # reshape (not view) so non-contiguous inputs are accepted too.
            value = value.reshape(value.size(0), -1)

        dense_layer = self.dense_transforms[feature.name]
        expected_in_dim = self.dense_input_dims[feature.name]
        if value.shape[1] != expected_in_dim:
            raise ValueError(
                f"Dense feature '{feature.name}' expects {expected_in_dim} inputs but "
                f"got {value.shape[1]}."
            )

        return dense_layer(value)

    def _compute_output_dim(self):
        """Placeholder: currently returns None, so ``self.output_dim`` is None."""
        return
|
|
282
|
+
|
|
283
|
+
class InputMask(nn.Module):
    """Build a float padding mask for a sequence feature.

    Positions whose id differs from the feature's ``padding_idx`` (or from 0
    when ``padding_idx`` is None) become 1.0; all others become 0.0. A
    singleton dimension is inserted at position 1 (``[B, L]`` ids give a
    ``[B, 1, L]`` mask), the layout the pooling layers pass to ``torch.bmm``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, fea, seq_tensor=None):
        """
        Args:
            x: Mapping from feature name to tensor; read only when
                ``seq_tensor`` is None.
            fea: Feature spec providing ``name`` and ``padding_idx``.
            seq_tensor: Optional id tensor to use instead of ``x[fea.name]``.

        Returns:
            Float mask. A 1-D id tensor is first treated as ``[N, 1]``.
        """
        ids = x[fea.name] if seq_tensor is None else seq_tensor
        pad_value = 0 if fea.padding_idx is None else fea.padding_idx
        keep = ids.long().ne(pad_value)
        if keep.dim() == 1:
            keep = keep.unsqueeze(-1)
        return keep.unsqueeze(1).float()
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class LR(nn.Module):
    """Linear "wide" component of Wide&Deep (Cheng et al., 2016): ``y = W x + b``.

    Args:
        input_dim: Size of the last input dimension.
        sigmoid: Squash the single output with a sigmoid when True.
    """

    def __init__(self, input_dim, sigmoid=False):
        super().__init__()
        self.sigmoid = sigmoid
        self.fc = nn.Linear(input_dim, 1, bias=True)

    def forward(self, x):
        score = self.fc(x)
        return torch.sigmoid(score) if self.sigmoid else score
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class ConcatPooling(nn.Module):
    """Merge dimensions 1 and 2, e.g. ``[B, L, D] -> [B, L * D]``.

    ``mask`` is accepted for interface parity with the other pooling layers
    but is ignored, so padded positions are kept.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        return torch.flatten(x, 1, 2)
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class AveragePooling(nn.Module):
    """Average embeddings over dimension 1, optionally ignoring padded positions.

    With a float ``mask`` (presumably ``[B, 1, L]`` as built by InputMask),
    the masked sum is divided by the number of unmasked positions. A tiny
    epsilon keeps all-padding rows at 0 instead of NaN.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        if mask is None:
            return x.mean(dim=1)
        masked_sum = torch.bmm(mask, x).squeeze(1)
        valid_count = mask.sum(dim=-1).float()
        return masked_sum / (valid_count + 1e-16)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
class SumPooling(nn.Module):
    """Sum embeddings over dimension 1, optionally skipping masked positions."""

    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        if mask is not None:
            # Batched (1 x L) @ (L x D): only positions with mask weight contribute.
            return torch.bmm(mask, x).squeeze(1)
        return x.sum(dim=1)
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
class MLP(nn.Module):
    """Feed-forward tower.

    Builds ``[Linear -> BatchNorm1d -> activation -> Dropout]`` for each
    hidden width, optionally followed by a final ``Linear(..., 1)``.

    Args:
        input_dim: Width of the input features.
        output_layer: Append a single-unit linear output layer when True.
        dims: Hidden layer widths; ``None`` means no hidden layers.
        dropout: Dropout probability after every hidden block.
        activation: Name passed to ``activation_layer`` for every hidden block.
    """

    def __init__(self, input_dim, output_layer=True, dims=None, dropout=0, activation="relu"):
        super().__init__()
        hidden_dims = [] if dims is None else dims
        blocks = []
        in_features = input_dim
        for out_features in hidden_dims:
            blocks += [
                nn.Linear(in_features, out_features),
                nn.BatchNorm1d(out_features),
                activation_layer(activation),
                nn.Dropout(p=dropout),
            ]
            in_features = out_features
        if output_layer:
            blocks.append(nn.Linear(in_features, 1))
        self.mlp = nn.Sequential(*blocks)

    def forward(self, x):
        return self.mlp(x)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
class FM(nn.Module):
    """Second-order Factorization Machine term (Rendle, 2010).

    Uses ``sum_{i<j} v_i * v_j = 0.5 * ((sum_i v_i)^2 - sum_i v_i^2)``, taken
    over dimension 1 (fields) and element-wise along the remaining dimensions.

    Args:
        reduce_sum: Also sum over the embedding dimension, keeping it as size 1.
    """

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        field_sum = x.sum(dim=1)
        interaction = field_sum.pow(2) - x.pow(2).sum(dim=1)
        if self.reduce_sum:
            interaction = interaction.sum(dim=1, keepdim=True)
        return interaction * 0.5
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class CIN(nn.Module):
    """Compressed Interaction Network from xDeepFM (Lian et al., 2018).

    Each layer does three things:

    1. Forms element-wise products between every base field embedding
       (``x0``) and every channel of the previous hidden map (``h``).
    2. Compresses those products with a 1x1 ``Conv1d`` followed by ReLU.
    3. Contributes channels that are sum-pooled over the embedding dimension
       and fed to a final ``Linear(..., 1)``.

    Args:
        input_dim: Number of input fields (dimension 1 of the input).
        cin_size: Output channel count of each CIN layer.
        split_half: If True, every layer except the last sends part of its
            channels to the output and the rest to the next layer. For an odd
            size the output part receives the extra channel.

    Raises:
        ValueError: If ``split_half`` is set and a non-final layer has fewer
            than 2 channels (one of the two parts would be empty).
    """

    def __init__(self, input_dim, cin_size, split_half=True):
        super().__init__()
        self.num_layers = len(cin_size)
        self.split_half = split_half
        self.conv_layers = torch.nn.ModuleList()
        prev_dim, fc_input_dim = input_dim, 0
        for i, cross_layer_size in enumerate(cin_size):
            self.conv_layers.append(
                torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True)
            )
            if self.split_half and i != self.num_layers - 1:
                if cross_layer_size < 2:
                    raise ValueError(
                        f"cin_size[{i}] must be >= 2 when split_half=True (got {cross_layer_size})."
                    )
                next_dim = cross_layer_size // 2
                out_dim = cross_layer_size - next_dim
            else:
                next_dim = out_dim = cross_layer_size
            prev_dim = next_dim
            fc_input_dim += out_dim
        self.fc = torch.nn.Linear(fc_input_dim, 1)

    def forward(self, x):
        """x: ``[batch, num_fields, embed_dim]`` field embeddings. Returns ``[batch, 1]``."""
        pooled_inputs = []
        x0, h = x.unsqueeze(2), x
        last = self.num_layers - 1
        for i, conv in enumerate(self.conv_layers):
            # [B, F, 1, D] * [B, 1, H, D] -> [B, F, H, D]: every (x0 field, h channel) pair.
            z = x0 * h.unsqueeze(1)
            batch_size, f0_dim, fin_dim, embed_dim = z.shape
            z = F.relu(conv(z.view(batch_size, f0_dim * fin_dim, embed_dim)))
            if self.split_half and i != last:
                # Explicit sections so odd channel counts yield exactly two parts.
                channels = z.shape[1]
                half = channels // 2
                z, h = torch.split(z, [channels - half, half], dim=1)
            else:
                h = z
            pooled_inputs.append(z)
        return self.fc(torch.sum(torch.cat(pooled_inputs, dim=1), 2))
|
|
423
|
+
|
|
424
|
+
class CrossLayer(nn.Module):
    """Single DCN cross operation (Wang et al., 2017): ``(w^T x_i) * x_0 + b``.

    Note: the standard DCN formula also adds ``x_i`` as a residual. This
    layer does not, so any residual has to be added by the caller.
    """

    def __init__(self, input_dim):
        super(CrossLayer, self).__init__()
        self.w = torch.nn.Linear(input_dim, 1, bias=False)
        self.b = torch.nn.Parameter(torch.zeros(input_dim))

    def forward(self, x_0, x_i):
        scale = self.w(x_i)  # [..., 1]: scalar projection of x_i
        return scale * x_0 + self.b
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
class CrossNetwork(nn.Module):
    """Stack of DCN cross layers (Wang et al., 2017).

    Layer ``l`` computes ``x_{l+1} = x_0 * (w_l^T x_l) + b_l + x_l``. The
    projection acts on the last dimension, which must equal ``input_dim``.

    Args:
        input_dim: Size of the last input dimension.
        num_layers: Number of stacked cross layers.
    """

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList(torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers))
        self.b = torch.nn.ParameterList(torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers))

    def forward(self, x):
        """
        :param x: Float tensor whose last dimension is ``input_dim``
            (e.g. ``(batch_size, input_dim)`` flattened embeddings).
        """
        x0 = x
        for proj, bias in zip(self.w, self.b):
            x = x0 * proj(x) + bias + x
        return x
|
|
455
|
+
|
|
456
|
+
class CrossNetV2(nn.Module):
    """Full-rank (matrix) cross network from DCN-V2 (Wang et al., 2021).

    Layer ``l`` computes ``x_{l+1} = x_0 * (W_l x_l + b_l) + x_l``, where
    ``*`` is the element-wise product and ``W_l`` is ``input_dim x input_dim``.

    Args:
        input_dim: Size of the last input dimension.
        num_layers: Number of stacked cross layers.
    """

    def __init__(self, input_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])

    def forward(self, x):
        x0 = x
        for i in range(self.num_layers):
            # The bias sits inside the Hadamard product with x0, as in DCN-V2.
            x = x0 * (self.w[i](x) + self.b[i]) + x
        return x
|
|
470
|
+
|
|
471
|
+
class CrossNetMix(nn.Module):
    """Mixture of low-rank cross experts from DCN-V2 (Wang et al., 2021).

    For layer ``l`` and expert ``e``:
    ``E_e(x_l) = x_0 * (U_e tanh(C_e tanh(V_e^T x_l)) + b_l)``.
    The layer output is ``x_{l+1} = sum_e softmax_e(G_e(x_l)) * E_e(x_l) + x_l``.

    Args:
        input_dim: Width ``d`` of the flat input vector.
        num_layers: Number of cross layers.
        low_rank: Rank ``r`` of each expert's projection.
        num_experts: Experts per layer.

    Input ``[batch, input_dim]`` -> output ``[batch, input_dim]``.
    """

    def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
        super(CrossNetMix, self).__init__()
        self.num_layers = num_layers
        self.num_experts = num_experts

        # U: (num_experts, input_dim, low_rank) per layer; projects back to R^d.
        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, input_dim, low_rank))) for _ in range(self.num_layers)])
        # V: (num_experts, input_dim, low_rank) per layer; V^T projects to R^r.
        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, input_dim, low_rank))) for _ in range(self.num_layers)])
        # C: (num_experts, low_rank, low_rank) per layer.
        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
            torch.empty(num_experts, low_rank, low_rank))) for _ in range(self.num_layers)])
        # One gating unit per expert, shared by all layers.
        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for _ in range(self.num_experts)])

        # Per-layer bias (input_dim, 1), shared by all experts of that layer.
        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
            torch.empty(input_dim, 1))) for _ in range(self.num_layers)])

    def forward(self, x):
        x_0 = x.unsqueeze(2)  # (bs, input_dim, 1)
        x_l = x_0
        for i in range(self.num_layers):
            output_of_experts = []
            gating_score_experts = []
            for expert_id in range(self.num_experts):
                # (1) gating logit G_e(x_l): (bs, 1)
                gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))

                # (2) expert E_e(x_l): project to R^r, apply the low-rank
                # non-linearity, then project back to R^d.
                v_x = torch.tanh(torch.matmul(self.v_list[i][expert_id].t(), x_l))  # (bs, low_rank, 1)
                v_x = torch.tanh(torch.matmul(self.c_list[i][expert_id], v_x))  # (bs, low_rank, 1)
                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, input_dim, 1)

                dot_ = x_0 * (uv_x + self.bias[i])  # Hadamard product with x_0
                output_of_experts.append(dot_.squeeze(2))

            # (3) softmax-weighted mixture of experts, plus the residual.
            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, input_dim, num_experts)
            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
            moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))  # (bs, input_dim, 1)
            x_l = moe_out + x_l

        # Squeeze only the trailing singleton. A bare squeeze() would also drop
        # the batch dim when bs == 1 (or the feature dim when input_dim == 1).
        return x_l.squeeze(2)
|
|
529
|
+
|
|
530
|
+
class SENETLayer(nn.Module):
    """Squeeze-and-Excitation over fields, as used in FiBiNET (Huang et al., 2019).

    Squeeze: mean-pool each field embedding to a scalar. Excite: two
    bias-free linear layers, each followed by ReLU, give one weight per
    field. Re-weight: scale each field embedding by its weight.

    Args:
        num_fields: Number of fields (dimension 1 of the input).
        reduction_ratio: Bottleneck ratio; the hidden width is
            ``max(1, int(num_fields / reduction_ratio))``.
    """

    def __init__(self, num_fields, reduction_ratio=3):
        super(SENETLayer, self).__init__()
        bottleneck = max(1, int(num_fields / reduction_ratio))
        self.mlp = nn.Sequential(
            nn.Linear(num_fields, bottleneck, bias=False),
            nn.ReLU(),
            nn.Linear(bottleneck, num_fields, bias=False),
            nn.ReLU(),
        )

    def forward(self, x):
        field_summary = x.mean(dim=-1)  # one scalar per field
        field_weights = self.mlp(field_summary)
        return x * field_weights.unsqueeze(-1)
|
|
545
|
+
|
|
546
|
+
class BiLinearInteractionLayer(nn.Module):
    """Bilinear field interaction from FiBiNET (Huang et al., 2019).

    For each field pair ``(i, j)`` with ``i < j`` the output is
    ``W(v_i) * v_j`` (element-wise), where ``W`` is a bias-free ``nn.Linear``:

    * ``"field_all"``: one matrix shared by all pairs;
    * ``"field_each"``: one matrix per field ``i``;
    * ``"field_interaction"``: one matrix per pair.

    Input ``[B, num_fields, input_dim]``. Output
    ``[B, num_fields * (num_fields - 1) / 2, input_dim]``.

    Raises:
        NotImplementedError: For any other ``bilinear_type``.
    """

    def __init__(self, input_dim, num_fields, bilinear_type = "field_interaction"):
        super(BiLinearInteractionLayer, self).__init__()
        self.bilinear_type = bilinear_type

        def make_linear():
            return nn.Linear(input_dim, input_dim, bias=False)

        if bilinear_type == "field_all":
            self.bilinear_layer = make_linear()
        elif bilinear_type == "field_each":
            self.bilinear_layer = nn.ModuleList([make_linear() for _ in range(num_fields)])
        elif bilinear_type == "field_interaction":
            self.bilinear_layer = nn.ModuleList(
                [make_linear() for _ in combinations(range(num_fields), 2)]
            )
        else:
            raise NotImplementedError()

    def forward(self, x):
        fields = torch.split(x, 1, dim=1)
        pairs = combinations(range(len(fields)), 2)
        if self.bilinear_type == "field_all":
            products = [self.bilinear_layer(fields[i]) * fields[j] for i, j in pairs]
        elif self.bilinear_type == "field_each":
            products = [self.bilinear_layer[i](fields[i]) * fields[j] for i, j in pairs]
        elif self.bilinear_type == "field_interaction":
            products = [self.bilinear_layer[k](fields[i]) * fields[j] for k, (i, j) in enumerate(pairs)]
        return torch.cat(products, dim=1)
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
class MultiInterestSA(nn.Module):
    """Self-attentive multi-interest extractor (MIND-style user encoder, Li et al., 2019).

    Computes ``A = softmax_over_seq(tanh(E W1) W2)`` and returns ``A^T E``:
    ``interest_num`` attention-weighted sums of the item embeddings ``E``.

    Args:
        embedding_dim: Item embedding width.
        interest_num: Number of interests to extract.
        hidden_dim: Attention hidden width; defaults to ``4 * embedding_dim``.
    """

    def __init__(self, embedding_dim, interest_num, hidden_dim=None):
        super(MultiInterestSA, self).__init__()
        self.embedding_dim = embedding_dim
        self.interest_num = interest_num
        # Honour an explicit hidden_dim; otherwise default to 4x the embedding width.
        self.hidden_dim = self.embedding_dim * 4 if hidden_dim is None else hidden_dim
        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
        # NOTE(review): W3 is not used in forward; presumably kept for checkpoint compatibility.
        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)

    def forward(self, seq_emb, mask=None):
        """
        Args:
            seq_emb: ``[batch, seq_len, embedding_dim]`` item embeddings.
            mask: Optional validity mask (1 = keep, 0 = padding), shaped
                ``[batch, seq_len, 1]`` or ``[batch, seq_len]``.

        Returns:
            ``[batch, interest_num, embedding_dim]`` interest embeddings.
        """
        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
        scores = torch.einsum('bsd, dk -> bsk', H, self.W2)
        if mask is not None:
            mask = mask.float()
            if mask.dim() == scores.dim() - 1:
                # [b, s] -> [b, s, 1] so it broadcasts over the interest axis.
                mask = mask.unsqueeze(-1)
            # Padded positions get a huge negative logit, i.e. ~0 attention.
            scores = scores + -1.e9 * (1 - mask)
        A = F.softmax(scores, dim=1).permute(0, 2, 1)
        return torch.matmul(A, seq_emb)
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
class CapsuleNetwork(nn.Module):
    """Dynamic-routing capsule layer for multi-interest extraction (MIND, Li et al., 2019).

    Args:
        embedding_dim: Item embedding width ``h``.
        seq_len: Sequence length ``s``; fixed, because it is used in reshapes.
        bilinear_type: How the per-interest predictions are produced:
            0 -> one shared ``Linear(h, h)`` repeated for every interest (MIND);
            1 -> ``Linear(h, h * interest_num)``;
            otherwise a per-position weight tensor ``[1, s, interest_num * h, h]``.
        interest_num: Number of interest capsules ``k``.
        routing_times: Number of routing iterations; the last one produces the output.
        relu_layer: If True, apply ``Linear(h, h)`` + ReLU to the output capsules.

    Raises:
        ValueError: If ``routing_times < 1``.
    """

    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
        super(CapsuleNetwork, self).__init__()
        if routing_times < 1:
            raise ValueError("routing_times must be >= 1.")
        self.embedding_dim = embedding_dim  # h
        self.seq_len = seq_len  # s
        self.bilinear_type = bilinear_type
        self.interest_num = interest_num
        self.routing_times = routing_times

        self.relu_layer = relu_layer
        self.stop_grad = True
        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
        if self.bilinear_type == 0:  # MIND
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        elif self.bilinear_type == 1:
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
        else:
            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
            nn.init.xavier_uniform_(self.w)

    @staticmethod
    def _squash(capsule):
        """Capsule squash: scale each vector by ``n / (1 + n) / sqrt(n)``, n = squared L2 norm."""
        cap_norm = torch.sum(torch.square(capsule), -1, True)
        scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
        return scalar_factor * capsule

    def forward(self, item_eb, mask):
        """
        Args:
            item_eb: ``[batch, seq_len, embedding_dim]`` behaviour embeddings.
            mask: ``[batch, seq_len]``; positions equal to 0 receive no routing weight.

        Returns:
            ``[batch, interest_num, embedding_dim]`` interest capsules.
        """
        if self.bilinear_type == 0:
            item_eb_hat = self.linear(item_eb)
            item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
        elif self.bilinear_type == 1:
            item_eb_hat = self.linear(item_eb)
        else:
            u = torch.unsqueeze(item_eb, dim=2)
            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)

        # [B, s, k*h] -> [B, k, s, h]: one predicted vector per (interest, item).
        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
        item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))

        # With stop_grad, intermediate routing iterations run on a detached copy,
        # so gradients reach item_eb_hat only through the final iteration.
        item_eb_hat_iter = item_eb_hat.detach() if self.stop_grad else item_eb_hat

        # Routing logits start at zero, or at random values for bilinear_type 0.
        if self.bilinear_type > 0:
            capsule_weight = torch.zeros(item_eb_hat.shape[0],
                                         self.interest_num,
                                         self.seq_len,
                                         device=item_eb.device,
                                         requires_grad=False)
        else:
            capsule_weight = torch.randn(item_eb_hat.shape[0],
                                         self.interest_num,
                                         self.seq_len,
                                         device=item_eb.device,
                                         requires_grad=False)

        # The padding mask does not change between iterations.
        atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
        paddings = torch.zeros_like(atten_mask, dtype=torch.float)

        last_iter = self.routing_times - 1
        for i in range(self.routing_times):  # dynamic routing iterations
            capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)

            if i < last_iter:
                interest_capsule = self._squash(torch.matmul(capsule_softmax_weight, item_eb_hat_iter))
                # Agreement between predictions and current capsules updates the logits.
                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
                capsule_weight = capsule_weight + delta_weight
            else:
                interest_capsule = self._squash(torch.matmul(capsule_softmax_weight, item_eb_hat))

        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))

        if self.relu_layer:
            interest_capsule = self.relu(interest_capsule)

        return interest_capsule
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
class FFM(nn.Module):
    """Field-aware FM pairwise interactions (Juan et al., 2016).

    Reads ``x[:, i, j, :]`` as field ``i``'s embedding specialised for field
    ``j``, i.e. the input is indexed as ``[B, F, F, D]``. For every pair
    ``i < j`` the interaction is ``x[:, i, j] * x[:, j, i]``.

    Args:
        num_fields: Number of fields ``F``.
        reduce_sum: If True, sum each interaction over ``D`` with keepdim
            (inner product); otherwise keep the element-wise product.

    ``forward`` returns ``[B, F*(F-1)/2, 1]`` or ``[B, F*(F-1)/2, D]``.
    """

    def __init__(self, num_fields, reduce_sum=True):
        super().__init__()
        self.num_fields = num_fields
        self.reduce_sum = reduce_sum

    def forward(self, x):
        pair_products = torch.stack(
            [x[:, i, j, :] * x[:, j, i, :] for i, j in combinations(range(self.num_fields), 2)],
            dim=1,
        )
        if self.reduce_sum:
            pair_products = pair_products.sum(dim=-1, keepdim=True)
        return pair_products
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
class CEN(nn.Module):
    """Field-attention block from FAT-DeepFFM (Wang et al., 2020).

    Each field-cross embedding is scored against its own learnable vector. The
    scores go through a bottleneck MLP, and every embedding is rescaled by
    the resulting attention weight.

    Args:
        embed_dim: size of each field-cross embedding.
        num_field_crosses: number of field-cross embeddings per sample.
        reduction_ratio: bottleneck factor; the hidden MLP width is
            ``num_field_crosses // reduction_ratio``.
    """

    def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
        super().__init__()

        # One learnable weighting vector per field cross (Eq. 7, FAT-DeepFFM).
        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)

        # Two-layer excitation MLP: squeeze to the bottleneck, expand back (Eq. 9).
        bottleneck_width = num_field_crosses // reduction_ratio
        self.mlp_att = MLP(
            num_field_crosses,
            dims=[bottleneck_width, num_field_crosses],
            output_layer=False,
            activation="relu",
        )

    def forward(self, em):
        """Rescale field-cross embeddings by learned field attention.

        Args:
            em: field-cross embeddings, presumably
                [batch_size, num_field_crosses, embed_dim].

        Returns:
            The attention-rescaled embeddings flattened from dim 1 onward, i.e.
            [batch_size, num_field_crosses * embed_dim] for a 3-D input.
        """
        # Descriptor per field cross: ReLU of its weighted sum (Eq. 7).
        weighted = self.u.squeeze(0) * em
        descriptor = F.relu(weighted.sum(-1))

        # Field attention weights (Eq. 9).
        attention = self.mlp_att(descriptor)

        # Rescale each embedding by its attention weight (Eq. 10).
        rescaled = em * attention.unsqueeze(-1)
        return torch.flatten(rescaled, start_dim=1)
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over field embeddings, from AutoInt (Song et al., 2019).

    Every field attends to every field in ``num_heads`` parallel subspaces.
    The heads are concatenated, optionally summed with a bias-free linear
    projection of the input (residual), and passed through ReLU.

    Args:
        embedding_dim: size of each field embedding (input and output).
        num_heads: number of attention heads; must divide ``embedding_dim``.
        dropout: dropout probability applied to the attention weights.
        use_residual: if True, add ``W_Res(x)`` to the attention output.

    Raises:
        ValueError: if ``embedding_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embedding_dim, num_heads=2, dropout=0.0, use_residual=True):
        super().__init__()
        if embedding_dim % num_heads:
            raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        self.use_residual = use_residual

        # Creation order (Q, K, V, Res) is kept deliberately: it fixes how the
        # RNG is consumed during init and the order of state_dict keys.
        self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
        if use_residual:
            self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size, num_fields):
        """Reshape [batch, fields, embedding_dim] into [batch, heads, fields, head_dim]."""
        return projected.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """
        Args:
            x: [batch_size, num_fields, embedding_dim]
        Returns:
            output: [batch_size, num_fields, embedding_dim], non-negative (ReLU)
        """
        batch_size, num_fields, _ = x.shape

        query, key, value = (
            self._split_heads(projection(x), batch_size, num_fields)
            for projection in (self.W_Q, self.W_K, self.W_V)
        )

        # Scaled dot-product attention of every field over all fields, per head.
        scores = query @ key.transpose(-2, -1) / (self.head_dim ** 0.5)
        weights = self.dropout(F.softmax(scores, dim=-1))

        # [batch, heads, fields, head_dim] -> [batch, fields, embedding_dim]
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch_size, num_fields, self.embedding_dim)

        if self.use_residual:
            merged = merged + self.W_Res(x)
        return F.relu(merged)
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
class AttentionPoolingLayer(nn.Module):
    """
    Attention pooling layer for DIN/DIEN.

    Scores each key (e.g. a user behavior sequence element) against the query
    (e.g. the candidate item). The score comes from an MLP over
    ``[query, key, query - key, query * key]``, and the layer returns the
    attention-weighted sum of the keys. Padded positions never contribute to
    the output.

    Args:
        embedding_dim: size of query and key embeddings.
        hidden_units: widths of the hidden layers of the scoring MLP.
        activation: activation name passed to ``activation_layer`` after
            each hidden layer.
        use_softmax: if True, scores are softmax-normalized over the sequence;
            otherwise the raw scores are used as weights.
    """

    def __init__(self, embedding_dim, hidden_units=(80, 40), activation='sigmoid', use_softmax=True):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_softmax = use_softmax

        # Scoring MLP input: [query, key, query-key, query*key] -> 4 * embedding_dim
        input_dim = 4 * embedding_dim
        layers = []
        for hidden_unit in hidden_units:
            layers.append(nn.Linear(input_dim, hidden_unit))
            layers.append(activation_layer(activation))
            input_dim = hidden_unit
        layers.append(nn.Linear(input_dim, 1))
        self.attention_net = nn.Sequential(*layers)

    @staticmethod
    def _padding_mask(keys, keys_length, mask):
        """Return a bool [batch_size, seq_len, 1] tensor, True at padded positions, or None.

        ``mask`` (non-zero = valid) takes precedence. Otherwise the mask is
        built from ``keys_length``.
        """
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(-1)
            return mask == 0
        if keys_length is not None:
            positions = torch.arange(keys.size(1), device=keys.device)
            lengths = torch.as_tensor(keys_length, device=keys.device).reshape(-1, 1)
            return (positions.unsqueeze(0) >= lengths).unsqueeze(-1)
        return None

    def forward(self, query, keys, keys_length=None, mask=None):
        """
        Args:
            query: [batch_size, embedding_dim] - candidate item embedding
            keys: [batch_size, seq_len, embedding_dim] - user behavior sequence
            keys_length: [batch_size] - actual length of each sequence (optional;
                used to build the padding mask when ``mask`` is not given)
            mask: [batch_size, seq_len, 1] or [batch_size, seq_len] - non-zero
                marks valid positions (optional)
        Returns:
            output: [batch_size, embedding_dim] - attention pooled representation
        """
        seq_len = keys.size(1)

        # Expand query to match sequence length: [batch_size, seq_len, embedding_dim]
        query_expanded = query.unsqueeze(1).expand(-1, seq_len, -1)

        attention_input = torch.cat([
            query_expanded,
            keys,
            query_expanded - keys,
            query_expanded * keys
        ], dim=-1)  # [batch_size, seq_len, 4*embedding_dim]

        attention_scores = self.attention_net(attention_input)  # [batch_size, seq_len, 1]

        padding = self._padding_mask(keys, keys_length, mask)

        if self.use_softmax:
            if padding is not None:
                # Push padded scores to the dtype minimum so softmax gives them
                # no weight; a fixed -1e9 would overflow float16.
                attention_scores = attention_scores.masked_fill(
                    padding, torch.finfo(attention_scores.dtype).min
                )
            attention_weights = F.softmax(attention_scores, dim=1)  # [batch_size, seq_len, 1]
        else:
            attention_weights = attention_scores

        if padding is not None:
            # Padded positions must contribute nothing to the pooled output.
            # Without softmax a large negative fill would leak in as a huge
            # weight, and a fully padded row would otherwise get a uniform
            # softmax over padding.
            attention_weights = attention_weights.masked_fill(padding, 0.0)

        # Weighted sum of keys: [batch_size, embedding_dim]
        return torch.sum(attention_weights * keys, dim=1)
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
class DynamicGRU(nn.Module):
    """GRU layer unrolled step by step in Python, used by DIEN (Zhou et al., 2019).

    Gate layout (reset | update | candidate) and the update rule
    ``h = (1 - z) * n + z * h`` match ``torch.nn.GRU``.

    Args:
        input_size: feature size of each sequence element.
        hidden_size: size of the hidden state.
        bias: if True, add learnable input and hidden biases.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Stacked [reset | update | candidate] weights; the random values are
        # overwritten by reset_parameters().
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every parameter uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
        std = 1.0 / (self.hidden_size) ** 0.5
        for weight in self.parameters():
            weight.data.uniform_(-std, std)

    def forward(self, x, att_scores=None):
        """
        Args:
            x: [batch_size, seq_len, input_size]
            att_scores: accepted for call compatibility; not used by this layer.
        Returns:
            output: [batch_size, seq_len, hidden_size] - hidden state after every step
            hidden: [batch_size, hidden_size] - final hidden state (zeros if seq_len == 0)
        """
        batch_size, seq_len, _ = x.shape

        # Match the input's dtype so the layer also works after .double()/.half();
        # a float32 state would mismatch the weights in F.linear.
        h = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)

        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # [batch_size, input_size]

            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
            gh = F.linear(h, self.weight_hh, self.bias_hh)
            i_r, i_i, i_n = gi.chunk(3, 1)
            h_r, h_i, h_n = gh.chunk(3, 1)

            resetgate = torch.sigmoid(i_r + h_r)
            inputgate = torch.sigmoid(i_i + h_i)
            newgate = torch.tanh(i_n + resetgate * h_n)
            h = newgate + inputgate * (h - newgate)

            outputs.append(h)

        if not outputs:
            # torch.stack/cat reject an empty list; return an empty time axis.
            return x.new_zeros(batch_size, 0, self.hidden_size), h
        return torch.stack(outputs, dim=1), h
|
|
921
|
+
|
|
922
|
+
|
|
923
|
+
class AUGRU(nn.Module):
    """GRU with attentional update gate (AUGRU), from DIEN (Zhou et al., 2019).

    Each step's attention score rescales the GRU update gate, so steps with
    low attention leave the hidden state almost unchanged::

        u'_t = a_t * u_t
        h_t  = (1 - u'_t) * h_{t-1} + u'_t * h~_t

    where ``u_t`` is the update gate and ``h~_t`` the candidate state.
    With ``a_t == 0`` the state is carried over unchanged.

    Args:
        input_size: feature size of each sequence element.
        hidden_size: size of the hidden state.
        bias: if True, add learnable input and hidden biases.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Stacked [reset | update | candidate] weights; the random values are
        # overwritten by reset_parameters().
        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every parameter uniformly in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
        std = 1.0 / (self.hidden_size) ** 0.5
        for weight in self.parameters():
            weight.data.uniform_(-std, std)

    def forward(self, x, att_scores):
        """
        Args:
            x: [batch_size, seq_len, input_size]
            att_scores: [batch_size, seq_len, 1] or [batch_size, seq_len] - attention scores
        Returns:
            output: [batch_size, seq_len, hidden_size] - hidden state after every step
            hidden: [batch_size, hidden_size] - final hidden state (zeros if seq_len == 0)
        """
        batch_size, seq_len, _ = x.shape
        if att_scores.dim() == 2:
            att_scores = att_scores.unsqueeze(-1)

        # Match the input's dtype so the layer also works after .double()/.half().
        h = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)

        outputs = []
        for t in range(seq_len):
            x_t = x[:, t, :]  # [batch_size, input_size]
            att_t = att_scores[:, t, :]  # [batch_size, 1]

            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
            gh = F.linear(h, self.weight_hh, self.bias_hh)
            i_r, i_i, i_n = gi.chunk(3, 1)
            h_r, h_i, h_n = gh.chunk(3, 1)

            resetgate = torch.sigmoid(i_r + h_r)
            inputgate = torch.sigmoid(i_i + h_i)
            newgate = torch.tanh(i_n + resetgate * h_n)

            # AUGRU: attention scales the update gate rather than replacing it
            # (replacing it would be AGRU and would leave inputgate unused).
            update = att_t * inputgate
            h = (1 - update) * h + update * newgate

            outputs.append(h)

        if not outputs:
            # torch.stack/cat reject an empty list; return an empty time axis.
            return x.new_zeros(batch_size, 0, self.hidden_size), h
        return torch.stack(outputs, dim=1), h
|