nextrec-0.2.6-py3-none-any.whl → nextrec-0.3.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +4 -8
- nextrec/basic/callback.py +1 -1
- nextrec/basic/features.py +33 -25
- nextrec/basic/layers.py +164 -601
- nextrec/basic/loggers.py +3 -4
- nextrec/basic/metrics.py +39 -115
- nextrec/basic/model.py +248 -174
- nextrec/basic/session.py +1 -5
- nextrec/data/__init__.py +12 -0
- nextrec/data/data_utils.py +3 -27
- nextrec/data/dataloader.py +26 -34
- nextrec/data/preprocessor.py +2 -1
- nextrec/loss/listwise.py +6 -4
- nextrec/loss/loss_utils.py +10 -6
- nextrec/loss/pairwise.py +5 -3
- nextrec/loss/pointwise.py +7 -13
- nextrec/models/match/mind.py +110 -1
- nextrec/models/multi_task/esmm.py +46 -27
- nextrec/models/multi_task/mmoe.py +48 -30
- nextrec/models/multi_task/ple.py +156 -141
- nextrec/models/multi_task/poso.py +413 -0
- nextrec/models/multi_task/share_bottom.py +43 -26
- nextrec/models/ranking/__init__.py +2 -0
- nextrec/models/ranking/autoint.py +1 -1
- nextrec/models/ranking/dcn.py +20 -1
- nextrec/models/ranking/dcn_v2.py +84 -0
- nextrec/models/ranking/deepfm.py +44 -18
- nextrec/models/ranking/dien.py +130 -27
- nextrec/models/ranking/masknet.py +13 -67
- nextrec/models/ranking/widedeep.py +39 -18
- nextrec/models/ranking/xdeepfm.py +34 -1
- nextrec/utils/common.py +26 -1
- nextrec-0.3.1.dist-info/METADATA +306 -0
- nextrec-0.3.1.dist-info/RECORD +56 -0
- {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/WHEEL +1 -1
- nextrec-0.2.6.dist-info/METADATA +0 -281
- nextrec-0.2.6.dist-info/RECORD +0 -54
- {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py
CHANGED
@@ -1,24 +1,22 @@
 """
 Layer implementations used across NextRec models.
 
-Date: create on 27/10/2025
-
+Date: create on 27/10/2025
+Checkpoint: edit on 29/11/2025
+Author: Yang Zhou, zyaztec@gmail.com
 """
-
 from __future__ import annotations
 
 from itertools import combinations
-from
+from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.utils.initializer import get_initializer
-
-Feature = Union[DenseFeature, SparseFeature, SequenceFeature]
+from nextrec.basic.activation import activation_layer
 
 __all__ = [
     "PredictionLayer",
@@ -30,57 +28,38 @@ __all__ = [
     "SumPooling",
     "MLP",
     "FM",
-    "FFM",
-    "CEN",
-    "CIN",
     "CrossLayer",
-    "CrossNetwork",
-    "CrossNetV2",
-    "CrossNetMix",
     "SENETLayer",
     "BiLinearInteractionLayer",
-    "MultiInterestSA",
-    "CapsuleNetwork",
     "MultiHeadSelfAttention",
     "AttentionPoolingLayer",
-    "DynamicGRU",
-    "AUGRU",
 ]
 
-
 class PredictionLayer(nn.Module):
     def __init__(
         self,
-        task_type:
-        task_dims:
+        task_type: str | list[str] = "binary",
+        task_dims: int | list[int] | None = None,
         use_bias: bool = True,
         return_logits: bool = False,
     ):
         super().__init__()
-
         if isinstance(task_type, str):
            self.task_types = [task_type]
         else:
            self.task_types = list(task_type)
-
         if len(self.task_types) == 0:
            raise ValueError("At least one task_type must be specified.")
-
         if task_dims is None:
            dims = [1] * len(self.task_types)
         elif isinstance(task_dims, int):
            dims = [task_dims]
         else:
            dims = list(task_dims)
-
         if len(dims) not in (1, len(self.task_types)):
-            raise ValueError(
-                "task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
-            )
-
+            raise ValueError("[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type.")
         if len(dims) == 1 and len(self.task_types) > 1:
            dims = dims * len(self.task_types)
-
         self.task_dims = dims
         self.total_dim = sum(self.task_dims)
         self.return_logits = return_logits
@@ -93,7 +72,6 @@ class PredictionLayer(nn.Module):
                raise ValueError("Each task dimension must be >= 1.")
            self._task_slices.append((start, start + dim))
            start += dim
-
         if use_bias:
            self.bias = nn.Parameter(torch.zeros(self.total_dim))
         else:
@@ -101,25 +79,18 @@
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.dim() == 1:
-            x = x.unsqueeze(
-
+            x = x.unsqueeze(0) # (1 * total_dim)
         if x.shape[-1] != self.total_dim:
-            raise ValueError(
-                f"Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
-            )
-
+            raise ValueError(f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim}).")
         logits = x if self.bias is None else x + self.bias
-        outputs
-
+        outputs = []
         for task_type, (start, end) in zip(self.task_types, self._task_slices):
-            task_logits = logits[..., start:end]
+            task_logits = logits[..., start:end] # Extract logits for the current task
            if self.return_logits:
                outputs.append(task_logits)
                continue
-
            activation = self._get_activation(task_type)
            outputs.append(activation(task_logits))
-
         result = torch.cat(outputs, dim=-1)
         if result.shape[-1] == 1:
            result = result.squeeze(-1)
@@ -127,17 +98,16 @@ class PredictionLayer(nn.Module):
 
     def _get_activation(self, task_type: str):
         task = task_type.lower()
-        if task
+        if task == 'binary':
            return torch.sigmoid
-        if task
+        if task == 'regression':
            return lambda x: x
-        if task
+        if task == 'multiclass':
            return lambda x: torch.softmax(x, dim=-1)
-        raise ValueError(f"Unsupported task_type '{task_type}'.")
-
+        raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
 
 class EmbeddingLayer(nn.Module):
-    def __init__(self, features:
+    def __init__(self, features: list):
         super().__init__()
         self.features = list(features)
         self.embed_dict = nn.ModuleDict()
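For orientation, here is a minimal usage sketch (not part of the wheel) of the refactored PredictionLayer API shown in the hunks above; the batch size and task mix are illustrative, and the behaviour assumed is what the new constructor and forward implement.

    # Hypothetical example: one head serving a binary task and a regression task.
    import torch
    from nextrec.basic.layers import PredictionLayer

    head = PredictionLayer(task_type=["binary", "regression"], task_dims=[1, 1], use_bias=True)
    logits = torch.randn(32, 2)   # (batch, total_dim) with total_dim = 1 + 1
    preds = head(logits)          # sigmoid on the binary slice, identity on the regression slice
    print(preds.shape)            # torch.Size([32, 2])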
@@ -148,23 +118,22 @@ class EmbeddingLayer(nn.Module):
            if isinstance(feature, (SparseFeature, SequenceFeature)):
                if feature.embedding_name in self.embed_dict:
                    continue
-
-
-
-
-                    padding_idx=feature.padding_idx
-
-
-
-
-                    init_type=feature.init_type,
-
-                    param=feature.init_params,
-                )
-                initialization(embedding.weight)
+                if getattr(feature, "pretrained_weight", None) is not None:
+                    weight = feature.pretrained_weight # type: ignore[assignment]
+                    if weight.shape != (feature.vocab_size, feature.embedding_dim): # type: ignore[assignment]
+                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).") # type: ignore[assignment]
+                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx) # type: ignore[assignment]
+                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained # type: ignore[assignment]
+                else:
+                    embedding = nn.Embedding(num_embeddings=feature.vocab_size, embedding_dim=feature.embedding_dim, padding_idx=feature.padding_idx)
+                    embedding.weight.requires_grad = feature.trainable
+                    initialization = get_initializer(init_type=feature.init_type, activation="linear", param=feature.init_params)
+                    initialization(embedding.weight)
                self.embed_dict[feature.embedding_name] = embedding
-
            elif isinstance(feature, DenseFeature):
+                if not feature.use_embedding:
+                    self.dense_input_dims[feature.name] = max(int(getattr(feature, "input_dim", 1)), 1)
+                    continue
                if feature.name in self.dense_transforms:
                    continue
                in_dim = max(int(getattr(feature, "input_dim", 1)), 1)
@@ -174,15 +143,14 @@ class EmbeddingLayer(nn.Module):
                nn.init.zeros_(dense_linear.bias)
                self.dense_transforms[feature.name] = dense_linear
                self.dense_input_dims[feature.name] = in_dim
-
            else:
-                raise TypeError(f"Unsupported feature type: {type(feature)}")
+                raise TypeError(f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}")
         self.output_dim = self._compute_output_dim()
 
     def forward(
         self,
         x: dict[str, torch.Tensor],
-        features:
+        features: list[object],
         squeeze_dim: bool = False,
     ) -> torch.Tensor:
         sparse_embeds: list[torch.Tensor] = []
@@ -208,8 +176,7 @@ class EmbeddingLayer(nn.Module):
                elif feature.combiner == "concat":
                    pooling_layer = ConcatPooling()
                else:
-                    raise ValueError(f"Unknown combiner for {feature.name}: {feature.combiner}")
-
+                    raise ValueError(f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}")
                feature_mask = InputMask()(x, feature, seq_input)
                sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
 
@@ -223,107 +190,116 @@ class EmbeddingLayer(nn.Module):
                pieces.append(torch.cat(flattened_sparse, dim=1))
            if dense_embeds:
                pieces.append(torch.cat(dense_embeds, dim=1))
-
            if not pieces:
-                raise ValueError("No input features found for EmbeddingLayer.")
-
+                raise ValueError("[EmbeddingLayer Error]: No input features found for EmbeddingLayer.")
            return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
-
+
         # squeeze_dim=False requires embeddings with identical last dimension
         output_embeddings = list(sparse_embeds)
         if dense_embeds:
-            target_dim = None
            if output_embeddings:
                target_dim = output_embeddings[0].shape[-1]
-
-
-
-
-
-
-
-
-
+                for emb in dense_embeds:
+                    if emb.shape[-1] != target_dim:
+                        raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}.")
+                output_embeddings.extend(emb.unsqueeze(1) for emb in dense_embeds)
+            else:
+                dims = {emb.shape[-1] for emb in dense_embeds}
+                if len(dims) != 1:
+                    raise ValueError(f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}.")
+                output_embeddings = [emb.unsqueeze(1) for emb in dense_embeds]
         if not output_embeddings:
-            raise ValueError(
-                "squeeze_dim=False requires at least one sparse/sequence feature or "
-                "dense features with identical projected dimensions."
-            )
-
+            raise ValueError("[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions.")
         return torch.cat(output_embeddings, dim=1)
 
     def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
         if feature.name not in x:
-            raise KeyError(f"Dense feature '{feature.name}' is missing from input.")
-
+            raise KeyError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input.")
         value = x[feature.name].float()
         if value.dim() == 1:
            value = value.unsqueeze(-1)
         else:
            value = value.view(value.size(0), -1)
-
-        dense_layer = self.dense_transforms[feature.name]
-        expected_in_dim = self.dense_input_dims[feature.name]
+        expected_in_dim = self.dense_input_dims.get(feature.name, max(int(getattr(feature, "input_dim", 1)), 1))
         if value.shape[1] != expected_in_dim:
-            raise ValueError(
-
-
-
-
+            raise ValueError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}.")
+        if not feature.use_embedding:
+            return value
+        dense_layer = self.dense_transforms[feature.name]
         return dense_layer(value)
 
-    def _compute_output_dim(self):
-
+    def _compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
+        """
+        Compute flattened embedding dimension for provided features or all tracked features.
+        Deduplicates by feature name to avoid double-counting shared embeddings.
+        """
+        candidates = list(features) if features is not None else self.features
+        unique_feats = OrderedDict((feat.name, feat) for feat in candidates) # type: ignore[assignment]
+        dim = 0
+        for feat in unique_feats.values():
+            if isinstance(feat, DenseFeature):
+                in_dim = max(int(getattr(feat, "input_dim", 1)), 1)
+                emb_dim = getattr(feat, "embedding_dim", None)
+                out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+                dim += out_dim
+            elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
+                dim += feat.embedding_dim * feat.max_len
+            else:
+                dim += feat.embedding_dim # type: ignore[assignment]
+        return dim
+
+    def get_input_dim(self, features: list[object] | None = None) -> int:
+        return self._compute_output_dim(features) # type: ignore[assignment]
+
+    @property
+    def input_dim(self) -> int:
+        return self.output_dim
 
 class InputMask(nn.Module):
     """Utility module to build sequence masks for pooling layers."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x,
-        values = seq_tensor if seq_tensor is not None else x[
-        if
-            mask = (values.long() !=
+    def forward(self, x: dict[str, torch.Tensor], feature: SequenceFeature, seq_tensor: torch.Tensor | None = None):
+        values = seq_tensor if seq_tensor is not None else x[feature.name]
+        if feature.padding_idx is not None:
+            mask = (values.long() != feature.padding_idx)
         else:
            mask = (values.long() != 0)
         if mask.dim() == 1:
            mask = mask.unsqueeze(-1)
         return mask.unsqueeze(1).float()
 
-
 class LR(nn.Module):
     """Wide component from Wide&Deep (Cheng et al., 2016)."""
-
-
+    def __init__(
+        self,
+        input_dim: int,
+        sigmoid: bool = False):
         super().__init__()
         self.sigmoid = sigmoid
         self.fc = nn.Linear(input_dim, 1, bias=True)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.sigmoid:
            return torch.sigmoid(self.fc(x))
         else:
            return self.fc(x)
 
-
 class ConcatPooling(nn.Module):
     """Concatenates sequence embeddings along the temporal dimension."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x, mask=None):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         return x.flatten(start_dim=1, end_dim=2)
 
-
 class AveragePooling(nn.Module):
     """Mean pooling with optional padding mask."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x, mask=None):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         if mask is None:
            return torch.mean(x, dim=1)
         else:
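The EmbeddingLayer hunk above adds a pretrained-embedding path built on nn.Embedding.from_pretrained. The standalone sketch below shows only that PyTorch call with dummy data; the feature attributes it stands in for (vocab_size, embedding_dim, freeze_pretrained, padding_idx) live in nextrec.basic.features and are assumed rather than reproduced here.

    import torch
    import torch.nn as nn

    weight = torch.randn(1000, 16)   # stand-in for feature.pretrained_weight, shape (vocab_size, embedding_dim)
    emb = nn.Embedding.from_pretrained(embeddings=weight, freeze=True, padding_idx=0)
    print(emb.weight.requires_grad)  # False, matching freeze_pretrained=True
    ids = torch.tensor([[1, 2, 0]])
    print(emb(ids).shape)            # torch.Size([1, 3, 16])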
@@ -331,24 +307,26 @@ class AveragePooling(nn.Module):
            non_padding_length = mask.sum(dim=-1)
            return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
 
-
 class SumPooling(nn.Module):
     """Sum pooling with optional padding mask."""
-
     def __init__(self):
         super().__init__()
 
-    def forward(self, x, mask=None):
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
         if mask is None:
            return torch.sum(x, dim=1)
         else:
            return torch.bmm(mask, x).squeeze(1)
 
-
 class MLP(nn.Module):
     """Stacked fully connected layers used in the deep component."""
-
-
+    def __init__(
+        self,
+        input_dim: int,
+        output_layer: bool = True,
+        dims: list[int] | None = None,
+        dropout: float = 0.0,
+        activation: str = "relu"):
         super().__init__()
         if dims is None:
            dims = []
@@ -366,15 +344,13 @@ class MLP(nn.Module):
     def forward(self, x):
         return self.mlp(x)
 
-
 class FM(nn.Module):
     """Factorization Machine (Rendle, 2010) second-order interaction term."""
-
-    def __init__(self, reduce_sum=True):
+    def __init__(self, reduce_sum: bool = True):
         super().__init__()
         self.reduce_sum = reduce_sum
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         square_of_sum = torch.sum(x, dim=1)**2
         sum_of_square = torch.sum(x**2, dim=1)
         ix = square_of_sum - sum_of_square
@@ -382,157 +358,30 @@ class FM(nn.Module):
            ix = torch.sum(ix, dim=1, keepdim=True)
         return 0.5 * ix
 
-
-class CIN(nn.Module):
-    """Compressed Interaction Network from xDeepFM (Lian et al., 2018)."""
-
-    def __init__(self, input_dim, cin_size, split_half=True):
-        super().__init__()
-        self.num_layers = len(cin_size)
-        self.split_half = split_half
-        self.conv_layers = torch.nn.ModuleList()
-        prev_dim, fc_input_dim = input_dim, 0
-        for i in range(self.num_layers):
-            cross_layer_size = cin_size[i]
-            self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, stride=1, dilation=1, bias=True))
-            if self.split_half and i != self.num_layers - 1:
-                cross_layer_size //= 2
-            prev_dim = cross_layer_size
-            fc_input_dim += prev_dim
-        self.fc = torch.nn.Linear(fc_input_dim, 1)
-
-    def forward(self, x):
-        xs = list()
-        x0, h = x.unsqueeze(2), x
-        for i in range(self.num_layers):
-            x = x0 * h.unsqueeze(1)
-            batch_size, f0_dim, fin_dim, embed_dim = x.shape
-            x = x.view(batch_size, f0_dim * fin_dim, embed_dim)
-            x = F.relu(self.conv_layers[i](x))
-            if self.split_half and i != self.num_layers - 1:
-                x, h = torch.split(x, x.shape[1] // 2, dim=1)
-            else:
-                h = x
-            xs.append(x)
-        return self.fc(torch.sum(torch.cat(xs, dim=1), 2))
-
 class CrossLayer(nn.Module):
     """Single cross layer used in DCN (Wang et al., 2017)."""
-
-    def __init__(self, input_dim):
+    def __init__(self, input_dim: int):
         super(CrossLayer, self).__init__()
         self.w = torch.nn.Linear(input_dim, 1, bias=False)
         self.b = torch.nn.Parameter(torch.zeros(input_dim))
 
-    def forward(self, x_0, x_i):
+    def forward(self, x_0: torch.Tensor, x_i: torch.Tensor) -> torch.Tensor:
         x = self.w(x_i) * x_0 + self.b
         return x
 
-
-class CrossNetwork(nn.Module):
-    """Stacked Cross Layers from DCN (Wang et al., 2017)."""
-
-    def __init__(self, input_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, 1, bias=False) for _ in range(num_layers)])
-        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
-
-    def forward(self, x):
-        """
-        :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)``
-        """
-        x0 = x
-        for i in range(self.num_layers):
-            xw = self.w[i](x)
-            x = x0 * xw + self.b[i] + x
-        return x
-
-class CrossNetV2(nn.Module):
-    """Vector-wise cross network proposed in DCN V2 (Wang et al., 2021)."""
-    def __init__(self, input_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        self.w = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(num_layers)])
-        self.b = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros((input_dim,))) for _ in range(num_layers)])
-
-
-    def forward(self, x):
-        x0 = x
-        for i in range(self.num_layers):
-            x =x0*self.w[i](x) + self.b[i] + x
-        return x
-
-class CrossNetMix(nn.Module):
-    """Mixture of low-rank cross experts from DCN V2 (Wang et al., 2021)."""
-
-    def __init__(self, input_dim, num_layers=2, low_rank=32, num_experts=4):
-        super(CrossNetMix, self).__init__()
-        self.num_layers = num_layers
-        self.num_experts = num_experts
-
-        # U: (input_dim, low_rank)
-        self.u_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
-        # V: (input_dim, low_rank)
-        self.v_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, input_dim, low_rank))) for i in range(self.num_layers)])
-        # C: (low_rank, low_rank)
-        self.c_list = torch.nn.ParameterList([nn.Parameter(nn.init.xavier_normal_(
-            torch.empty(num_experts, low_rank, low_rank))) for i in range(self.num_layers)])
-        self.gating = nn.ModuleList([nn.Linear(input_dim, 1, bias=False) for i in range(self.num_experts)])
-
-        self.bias = torch.nn.ParameterList([nn.Parameter(nn.init.zeros_(
-            torch.empty(input_dim, 1))) for i in range(self.num_layers)])
-
-    def forward(self, x):
-        x_0 = x.unsqueeze(2)  # (bs, in_features, 1)
-        x_l = x_0
-        for i in range(self.num_layers):
-            output_of_experts = []
-            gating_score_experts = []
-            for expert_id in range(self.num_experts):
-                # (1) G(x_l)
-                # compute the gating score by x_l
-                gating_score_experts.append(self.gating[expert_id](x_l.squeeze(2)))
-
-                # (2) E(x_l)
-                # project the input x_l to $\mathbb{R}^{r}$
-                v_x = torch.matmul(self.v_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)
-
-                # nonlinear activation in low rank space
-                v_x = torch.tanh(v_x)
-                v_x = torch.matmul(self.c_list[i][expert_id], v_x)
-                v_x = torch.tanh(v_x)
-
-                # project back to $\mathbb{R}^{d}$
-                uv_x = torch.matmul(self.u_list[i][expert_id], v_x)  # (bs, in_features, 1)
-
-                dot_ = uv_x + self.bias[i]
-                dot_ = x_0 * dot_  # Hadamard-product
-
-                output_of_experts.append(dot_.squeeze(2))
-
-            # (3) mixture of low-rank experts
-            output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
-            gating_score_experts = torch.stack(gating_score_experts, 1)  # (bs, num_experts, 1)
-            moe_out = torch.matmul(output_of_experts, gating_score_experts.softmax(1))
-            x_l = moe_out + x_l  # (bs, in_features, 1)
-
-        x_l = x_l.squeeze()  # (bs, in_features)
-        return x_l
-
 class SENETLayer(nn.Module):
     """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
-
-
+    def __init__(
+        self,
+        num_fields: int,
+        reduction_ratio: int = 3):
         super(SENETLayer, self).__init__()
         reduced_size = max(1, int(num_fields/ reduction_ratio))
         self.mlp = nn.Sequential(nn.Linear(num_fields, reduced_size, bias=False),
                                  nn.ReLU(),
                                  nn.Linear(reduced_size, num_fields, bias=False),
                                  nn.ReLU())
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         z = torch.mean(x, dim=-1, out=None)
         a = self.mlp(z)
         v = x*a.unsqueeze(-1)
@@ -540,8 +389,11 @@ class SENETLayer(nn.Module):
 
 class BiLinearInteractionLayer(nn.Module):
     """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
-
-
+    def __init__(
+        self,
+        input_dim: int,
+        num_fields: int,
+        bilinear_type: str = "field_interaction"):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
         if self.bilinear_type == "field_all":
@@ -553,263 +405,96 @@ class BiLinearInteractionLayer(nn.Module):
         else:
            raise NotImplementedError()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         feature_emb = torch.split(x, 1, dim=1)
         if self.bilinear_type == "field_all":
            bilinear_list = [self.bilinear_layer(v_i)*v_j for v_i, v_j in combinations(feature_emb, 2)]
         elif self.bilinear_type == "field_each":
-            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)]
+            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)] # type: ignore[assignment]
         elif self.bilinear_type == "field_interaction":
-            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
+            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))] # type: ignore[assignment]
         return torch.cat(bilinear_list, dim=1)
 
-
-class MultiInterestSA(nn.Module):
-    """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""
-
-    def __init__(self, embedding_dim, interest_num, hidden_dim=None):
-        super(MultiInterestSA, self).__init__()
-        self.embedding_dim = embedding_dim
-        self.interest_num = interest_num
-        if hidden_dim == None:
-            self.hidden_dim = self.embedding_dim * 4
-        self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
-        self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
-        self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
-
-    def forward(self, seq_emb, mask=None):
-        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
-        if mask != None:
-            A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
-            A = F.softmax(A, dim=1)
-        else:
-            A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
-        A = A.permute(0, 2, 1)
-        multi_interest_emb = torch.matmul(A, seq_emb)
-        return multi_interest_emb
-
-
-class CapsuleNetwork(nn.Module):
-    """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
-
-    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
-        super(CapsuleNetwork, self).__init__()
-        self.embedding_dim = embedding_dim  # h
-        self.seq_len = seq_len  # s
-        self.bilinear_type = bilinear_type
-        self.interest_num = interest_num
-        self.routing_times = routing_times
-
-        self.relu_layer = relu_layer
-        self.stop_grad = True
-        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
-        if self.bilinear_type == 0:  # MIND
-            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
-        elif self.bilinear_type == 1:
-            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
-        else:
-            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
-            nn.init.xavier_uniform_(self.w)
-
-    def forward(self, item_eb, mask):
-        if self.bilinear_type == 0:
-            item_eb_hat = self.linear(item_eb)
-            item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
-        elif self.bilinear_type == 1:
-            item_eb_hat = self.linear(item_eb)
-        else:
-            u = torch.unsqueeze(item_eb, dim=2)
-            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
-
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
-        item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
-        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
-
-        if self.stop_grad:
-            item_eb_hat_iter = item_eb_hat.detach()
-        else:
-            item_eb_hat_iter = item_eb_hat
-
-        if self.bilinear_type > 0:
-            capsule_weight = torch.zeros(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
-        else:
-            capsule_weight = torch.randn(item_eb_hat.shape[0],
-                                         self.interest_num,
-                                         self.seq_len,
-                                         device=item_eb.device,
-                                         requires_grad=False)
-
-        for i in range(self.routing_times):  # run dynamic routing 3 times
-            atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
-            paddings = torch.zeros_like(atten_mask, dtype=torch.float)
-
-            capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
-            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
-            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
-
-            if i < 2:
-                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
-                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
-                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
-                interest_capsule = scalar_factor * interest_capsule
-
-                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
-                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
-                capsule_weight = capsule_weight + delta_weight
-            else:
-                interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
-                cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
-                scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
-                interest_capsule = scalar_factor * interest_capsule
-
-        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
-
-        if self.relu_layer:
-            interest_capsule = self.relu(interest_capsule)
-
-        return interest_capsule
-
-
-class FFM(nn.Module):
-    """Field-aware Factorization Machine (Juan et al., 2016)."""
-
-    def __init__(self, num_fields, reduce_sum=True):
-        super().__init__()
-        self.num_fields = num_fields
-        self.reduce_sum = reduce_sum
-
-    def forward(self, x):
-        # compute (non-redundant) second order field-aware feature crossings
-        crossed_embeddings = []
-        for i in range(self.num_fields-1):
-            for j in range(i+1, self.num_fields):
-                crossed_embeddings.append(x[:, i, j, :] * x[:, j, i, :])
-        crossed_embeddings = torch.stack(crossed_embeddings, dim=1)
-
-        # if reduce_sum is true, the crossing operation is effectively inner product, other wise Hadamard-product
-        if self.reduce_sum:
-            crossed_embeddings = torch.sum(crossed_embeddings, dim=-1, keepdim=True)
-        return crossed_embeddings
-
-
-class CEN(nn.Module):
-    """Field-attentive interaction network from FAT-DeepFFM (Wang et al., 2020)."""
-
-    def __init__(self, embed_dim, num_field_crosses, reduction_ratio):
-        super().__init__()
-
-        # convolution weight (Eq.7 FAT-DeepFFM)
-        self.u = torch.nn.Parameter(torch.rand(num_field_crosses, embed_dim), requires_grad=True)
-
-        # two FC layers that computes the field attention
-        self.mlp_att = MLP(num_field_crosses, dims=[num_field_crosses//reduction_ratio, num_field_crosses], output_layer=False, activation="relu")
-
-
-    def forward(self, em):
-        # compute descriptor vector (Eq.7 FAT-DeepFFM), output shape [batch_size, num_field_crosses]
-        d = F.relu((self.u.squeeze(0) * em).sum(-1))
-
-        # compute field attention (Eq.9), output shape [batch_size, num_field_crosses]
-        s = self.mlp_att(d)
-
-        # rescale original embedding with field attention (Eq.10), output shape [batch_size, num_field_crosses, embed_dim]
-        aem = s.unsqueeze(-1) * em
-        return aem.flatten(start_dim=1)
-
-
 class MultiHeadSelfAttention(nn.Module):
     """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
-
-
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int = 2,
+        dropout: float = 0.0,
+        use_residual: bool = True):
         super().__init__()
         if embedding_dim % num_heads != 0:
-            raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
-
+            raise ValueError(f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})")
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
-
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
         if self.use_residual:
            self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
-
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            x:
+            x (torch.Tensor): Tensor of shape (batch_size, num_fields, embedding_dim)
+
         Returns:
-
+            torch.Tensor: Output tensor of shape (batch_size, num_fields, embedding_dim)
         """
         batch_size, num_fields, _ = x.shape
-
-        # Linear projections
         Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-
         # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
         Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
         K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
         V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
-
         # Attention scores
         scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
         attention_weights = F.softmax(scores, dim=-1)
         attention_weights = self.dropout(attention_weights)
-
-        # Apply attention to values
         attention_output = torch.matmul(attention_weights, V)  # [batch_size, num_heads, num_fields, head_dim]
-
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
         attention_output = attention_output.view(batch_size, num_fields, self.embedding_dim)
-
         # Residual connection
         if self.use_residual:
            output = attention_output + self.W_Res(x)
         else:
            output = attention_output
-
         output = F.relu(output)
-
         return output
 
-
 class AttentionPoolingLayer(nn.Module):
     """
     Attention pooling layer for DIN/DIEN
     Computes attention weights between query (candidate item) and keys (user behavior sequence)
     """
-
-
+    def __init__(
+        self,
+        embedding_dim: int,
+        hidden_units: list = [80, 40],
+        activation: str ='sigmoid',
+        use_softmax: bool = True):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.use_softmax = use_softmax
-
         # Build attention network
         # Input: [query, key, query-key, query*key] -> 4 * embedding_dim
         input_dim = 4 * embedding_dim
         layers = []
-
         for hidden_unit in hidden_units:
            layers.append(nn.Linear(input_dim, hidden_unit))
            layers.append(activation_layer(activation))
            input_dim = hidden_unit
-
         layers.append(nn.Linear(input_dim, 1))
         self.attention_net = nn.Sequential(*layers)
 
-    def forward(self, query, keys, keys_length=None, mask=None):
+    def forward(self, query: torch.Tensor, keys: torch.Tensor, keys_length: torch.Tensor | None = None, mask: torch.Tensor | None = None):
         """
         Args:
            query: [batch_size, embedding_dim] - candidate item embedding
@@ -819,162 +504,40 @@ class AttentionPoolingLayer(nn.Module):
         Returns:
            output: [batch_size, embedding_dim] - attention pooled representation
         """
-        batch_size,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        batch_size, sequence_length, embedding_dim = keys.shape
+        assert query.shape == (batch_size, embedding_dim), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
+        if mask is None and keys_length is not None:
+            # keys_length: (batch_size,)
+            device = keys.device
+            seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)  # (1, sequence_length)
+            mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
+        if mask is not None:
+            if mask.dim() == 2:
+                # (B, L)
+                mask = mask.unsqueeze(-1)
+            elif mask.dim() == 3 and mask.shape[1] == 1 and mask.shape[2] == sequence_length:
+                # (B, 1, L) -> (B, L, 1)
+                mask = mask.transpose(1, 2)
+            elif mask.dim() == 3 and mask.shape[1] == sequence_length and mask.shape[2] == 1:
+                pass
+            else:
+                raise ValueError(f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}")
+            mask = mask.to(keys.dtype)
+        # Expand query to (B, L, D)
+        query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
+        # [query, key, query-key, query*key] -> (B, L, 4D)
+        attention_input = torch.cat([query_expanded, keys, query_expanded - keys, query_expanded * keys], dim=-1,)
+        attention_scores = self.attention_net(attention_input)
         if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
-
-        # Apply softmax to get attention weights
+        # Get attention weights
         if self.use_softmax:
-
+            # softmax over seq_len
+            attention_weights = F.softmax(attention_scores, dim=1)  # (B, L, 1)
         else:
-            attention_weights = attention_scores
-
-
-
-
+            attention_weights = torch.sigmoid(attention_scores)
+        if mask is not None:
+            attention_weights = attention_weights * mask
+        # Weighted sum over keys: (B, L, 1) * (B, L, D) -> (B, D)
+        output = torch.sum(attention_weights * keys, dim=1)
         return output
-
-
-class DynamicGRU(nn.Module):
-    """Dynamic GRU unit with auxiliary loss path from DIEN (Zhou et al., 2019)."""
-    """
-    GRU with dynamic routing for DIEN
-    """
-
-    def __init__(self, input_size, hidden_size, bias=True):
-        super().__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-
-        # GRU parameters
-        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
-        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
-        if bias:
-            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
-            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
-        else:
-            self.register_parameter('bias_ih', None)
-            self.register_parameter('bias_hh', None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        std = 1.0 / (self.hidden_size) ** 0.5
-        for weight in self.parameters():
-            weight.data.uniform_(-std, std)
-
-    def forward(self, x, att_scores=None):
-        """
-        Args:
-            x: [batch_size, seq_len, input_size]
-            att_scores: [batch_size, seq_len] - attention scores for auxiliary loss
-        Returns:
-            output: [batch_size, seq_len, hidden_size]
-            hidden: [batch_size, hidden_size] - final hidden state
-        """
-        batch_size, seq_len, _ = x.shape
-
-        # Initialize hidden state
-        h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
-        outputs = []
-        for t in range(seq_len):
-            x_t = x[:, t, :]  # [batch_size, input_size]
-
-            # GRU computation
-            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
-            gh = F.linear(h, self.weight_hh, self.bias_hh)
-            i_r, i_i, i_n = gi.chunk(3, 1)
-            h_r, h_i, h_n = gh.chunk(3, 1)
-
-            resetgate = torch.sigmoid(i_r + h_r)
-            inputgate = torch.sigmoid(i_i + h_i)
-            newgate = torch.tanh(i_n + resetgate * h_n)
-            h = newgate + inputgate * (h - newgate)
-
-            outputs.append(h.unsqueeze(1))
-
-        output = torch.cat(outputs, dim=1)  # [batch_size, seq_len, hidden_size]
-
-        return output, h
-
-
-class AUGRU(nn.Module):
-    """Attention-aware GRU update gate used in DIEN (Zhou et al., 2019)."""
-    """
-    Attention-based GRU for DIEN
-    Uses attention scores to weight the update of hidden states
-    """
-
-    def __init__(self, input_size, hidden_size, bias=True):
-        super().__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-
-        self.weight_ih = nn.Parameter(torch.randn(3 * hidden_size, input_size))
-        self.weight_hh = nn.Parameter(torch.randn(3 * hidden_size, hidden_size))
-        if bias:
-            self.bias_ih = nn.Parameter(torch.randn(3 * hidden_size))
-            self.bias_hh = nn.Parameter(torch.randn(3 * hidden_size))
-        else:
-            self.register_parameter('bias_ih', None)
-            self.register_parameter('bias_hh', None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        std = 1.0 / (self.hidden_size) ** 0.5
-        for weight in self.parameters():
-            weight.data.uniform_(-std, std)
-
-    def forward(self, x, att_scores):
-        """
-        Args:
-            x: [batch_size, seq_len, input_size]
-            att_scores: [batch_size, seq_len, 1] - attention scores
-        Returns:
-            output: [batch_size, seq_len, hidden_size]
-            hidden: [batch_size, hidden_size] - final hidden state
-        """
-        batch_size, seq_len, _ = x.shape
-
-        h = torch.zeros(batch_size, self.hidden_size, device=x.device)
-
-        outputs = []
-        for t in range(seq_len):
-            x_t = x[:, t, :]  # [batch_size, input_size]
-            att_t = att_scores[:, t, :]  # [batch_size, 1]
-
-            gi = F.linear(x_t, self.weight_ih, self.bias_ih)
-            gh = F.linear(h, self.weight_hh, self.bias_hh)
-            i_r, i_i, i_n = gi.chunk(3, 1)
-            h_r, h_i, h_n = gh.chunk(3, 1)
-
-            resetgate = torch.sigmoid(i_r + h_r)
-            inputgate = torch.sigmoid(i_i + h_i)
-            newgate = torch.tanh(i_n + resetgate * h_n)
-
-            # Use attention score to control update
-            h = (1 - att_t) * h + att_t * newgate
-
-            outputs.append(h.unsqueeze(1))
-
-        output = torch.cat(outputs, dim=1)
-
-        return output, h