nextrec 0.4.1-py3-none-any.whl → 0.4.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +220 -106
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1082 -400
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +272 -95
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +53 -37
- nextrec/models/multi_task/mmoe.py +64 -45
- nextrec/models/multi_task/ple.py +101 -48
- nextrec/models/multi_task/poso.py +113 -36
- nextrec/models/multi_task/share_bottom.py +48 -35
- nextrec/models/ranking/afm.py +72 -37
- nextrec/models/ranking/autoint.py +72 -55
- nextrec/models/ranking/dcn.py +55 -35
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +32 -22
- nextrec/models/ranking/dien.py +155 -99
- nextrec/models/ranking/din.py +85 -57
- nextrec/models/ranking/fibinet.py +52 -32
- nextrec/models/ranking/fm.py +29 -23
- nextrec/models/ranking/masknet.py +91 -29
- nextrec/models/ranking/pnn.py +31 -28
- nextrec/models/ranking/widedeep.py +34 -26
- nextrec/models/ranking/xdeepfm.py +60 -38
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py
CHANGED
@@ -5,19 +5,21 @@ Date: create on 27/10/2025
 Checkpoint: edit on 29/11/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
-from __future__ import annotations
 
-from
-from collections import OrderedDict
+from __future__ import annotations
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from itertools import combinations
+from collections import OrderedDict
+
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.utils.initializer import get_initializer
 from nextrec.basic.activation import activation_layer
 
+
 class PredictionLayer(nn.Module):
     def __init__(
         self,
@@ -30,7 +32,7 @@ class PredictionLayer(nn.Module):
         self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
         if len(self.task_types) == 0:
             raise ValueError("At least one task_type must be specified.")
-
+
         if task_dims is None:
             dims = [1] * len(self.task_types)
         elif isinstance(task_dims, int):
@@ -38,7 +40,9 @@
         else:
             dims = list(task_dims)
         if len(dims) not in (1, len(self.task_types)):
-            raise ValueError(
+            raise ValueError(
+                "[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
+            )
         if len(dims) == 1 and len(self.task_types) > 1:
             dims = dims * len(self.task_types)
         self.task_dims = dims
@@ -62,27 +66,33 @@
         if x.dim() == 1:
             x = x.unsqueeze(0)  # (1 * total_dim)
         if x.shape[-1] != self.total_dim:
-            raise ValueError(
+            raise ValueError(
+                f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+            )
         logits = x if self.bias is None else x + self.bias
         outputs = []
         for task_type, (start, end) in zip(self.task_types, self._task_slices):
-            task_logits = logits[..., start:end]
+            task_logits = logits[..., start:end]  # logits for the current task
             if self.return_logits:
                 outputs.append(task_logits)
                 continue
             task = task_type.lower()
-            if task ==
-
-            elif task ==
-
-            elif task ==
-
+            if task == "binary":
+                outputs.append(torch.sigmoid(task_logits))
+            elif task == "regression":
+                outputs.append(task_logits)
+            elif task == "multiclass":
+                outputs.append(torch.softmax(task_logits, dim=-1))
             else:
-                raise ValueError(
-
-
+                raise ValueError(
+                    f"[PredictionLayer Error]: Unsupported task_type '{task_type}'."
+                )
+        result = torch.cat(
+            outputs, dim=-1
+        )  # single: (N,1), multi-task/multi-class: (N,total_dim)
         return result
 
+
 class EmbeddingLayer(nn.Module):
     def __init__(self, features: list):
         super().__init__()
@@ -96,20 +106,30 @@ class EmbeddingLayer(nn.Module):
                 if feature.embedding_name in self.embed_dict:
                     continue
                 if getattr(feature, "pretrained_weight", None) is not None:
-                    weight = feature.pretrained_weight
-                    if weight.shape != (feature.vocab_size, feature.embedding_dim):
-                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).")
-                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx)
-                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained
+                    weight = feature.pretrained_weight  # type: ignore[assignment]
+                    if weight.shape != (feature.vocab_size, feature.embedding_dim):  # type: ignore[assignment]
+                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).")  # type: ignore[assignment]
+                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx)  # type: ignore[assignment]
+                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained  # type: ignore[assignment]
                 else:
-                    embedding = nn.Embedding(
+                    embedding = nn.Embedding(
+                        num_embeddings=feature.vocab_size,
+                        embedding_dim=feature.embedding_dim,
+                        padding_idx=feature.padding_idx,
+                    )
                     embedding.weight.requires_grad = feature.trainable
-                    initialization = get_initializer(
+                    initialization = get_initializer(
+                        init_type=feature.init_type,
+                        activation="linear",
+                        param=feature.init_params,
+                    )
                     initialization(embedding.weight)
                 self.embed_dict[feature.embedding_name] = embedding
             elif isinstance(feature, DenseFeature):
                 if not feature.use_embedding:
-                    self.dense_input_dims[feature.name] = max(
+                    self.dense_input_dims[feature.name] = max(
+                        int(getattr(feature, "input_dim", 1)), 1
+                    )
                     continue
                 if feature.name in self.dense_transforms:
                     continue
@@ -121,7 +141,9 @@ class EmbeddingLayer(nn.Module):
                 self.dense_transforms[feature.name] = dense_linear
                 self.dense_input_dims[feature.name] = in_dim
             else:
-                raise TypeError(
+                raise TypeError(
+                    f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}"
+                )
         self.output_dim = self.compute_output_dim()
 
     def forward(
@@ -153,7 +175,9 @@
                 elif feature.combiner == "concat":
                     pooling_layer = ConcatPooling()
                 else:
-                    raise ValueError(
+                    raise ValueError(
+                        f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
+                    )
                 feature_mask = InputMask()(x, feature, seq_input)
                 sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
 
@@ -168,9 +192,11 @@
             if dense_embeds:
                 pieces.append(torch.cat(dense_embeds, dim=1))
             if not pieces:
-                raise ValueError(
+                raise ValueError(
+                    "[EmbeddingLayer Error]: No input features found for EmbeddingLayer."
+                )
             return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
-
+
         # squeeze_dim=False requires embeddings with identical last dimension
         output_embeddings = list(sparse_embeds)
         if dense_embeds:
@@ -178,36 +204,53 @@
                 target_dim = output_embeddings[0].shape[-1]
                 for emb in dense_embeds:
                     if emb.shape[-1] != target_dim:
-                        raise ValueError(
+                        raise ValueError(
+                            f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}."
+                        )
                 output_embeddings.extend(emb.unsqueeze(1) for emb in dense_embeds)
             else:
                 dims = {emb.shape[-1] for emb in dense_embeds}
                 if len(dims) != 1:
-                    raise ValueError(
+                    raise ValueError(
+                        f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}."
+                    )
                 output_embeddings = [emb.unsqueeze(1) for emb in dense_embeds]
         if not output_embeddings:
-            raise ValueError(
+            raise ValueError(
+                "[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions."
+            )
         return torch.cat(output_embeddings, dim=1)
 
-    def project_dense(
+    def project_dense(
+        self, feature: DenseFeature, x: dict[str, torch.Tensor]
+    ) -> torch.Tensor:
        if feature.name not in x:
-            raise KeyError(
+            raise KeyError(
+                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input."
+            )
        value = x[feature.name].float()
        if value.dim() == 1:
            value = value.unsqueeze(-1)
        else:
            value = value.view(value.size(0), -1)
-        expected_in_dim = self.dense_input_dims.get(
+        expected_in_dim = self.dense_input_dims.get(
+            feature.name, max(int(getattr(feature, "input_dim", 1)), 1)
+        )
        if value.shape[1] != expected_in_dim:
-            raise ValueError(
+            raise ValueError(
+                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}."
+            )
        if not feature.use_embedding:
-            return value
+            return value
        dense_layer = self.dense_transforms[feature.name]
        return dense_layer(value)
 
-    def compute_output_dim(
+    def compute_output_dim(
+        self,
+        features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None,
+    ) -> int:
        candidates = list(features) if features is not None else self.features
-        unique_feats = OrderedDict((feat.name, feat) for feat in candidates)
+        unique_feats = OrderedDict((feat.name, feat) for feat in candidates)  # type: ignore[assignment]
        dim = 0
        for feat in unique_feats.values():
            if isinstance(feat, DenseFeature):
@@ -218,28 +261,34 @@ class EmbeddingLayer(nn.Module):
            elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
                dim += feat.embedding_dim * feat.max_len
            else:
-                dim += feat.embedding_dim
+                dim += feat.embedding_dim  # type: ignore[assignment]
        return dim
 
    def get_input_dim(self, features: list[object] | None = None) -> int:
-        return self.compute_output_dim(features)
+        return self.compute_output_dim(features)  # type: ignore[assignment]
 
    @property
    def input_dim(self) -> int:
        return self.output_dim
 
+
class InputMask(nn.Module):
    def __init__(self):
        super().__init__()
 
-    def forward(
+    def forward(
+        self,
+        x: dict[str, torch.Tensor],
+        feature: SequenceFeature,
+        seq_tensor: torch.Tensor | None = None,
+    ):
        if seq_tensor is not None:
            values = seq_tensor
        else:
            values = x[feature.name]
        values = values.long()
        padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
-        mask =
+        mask = values != padding_idx
 
        if mask.dim() == 1:
            # [B] -> [B, 1, 1]
@@ -253,14 +302,14 @@ class InputMask(nn.Module):
            if mask.size(1) != 1 and mask.size(2) == 1:
                mask = mask.squeeze(-1).unsqueeze(1)
            else:
-                raise ValueError(
+                raise ValueError(
+                    f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
+                )
        return mask.float()
 
+
class LR(nn.Module):
-    def __init__(
-        self,
-        input_dim: int,
-        sigmoid: bool = False):
+    def __init__(self, input_dim: int, sigmoid: bool = False):
        super().__init__()
        self.sigmoid = sigmoid
        self.fc = nn.Linear(input_dim, 1, bias=True)
@@ -271,18 +320,24 @@
        else:
            return self.fc(x)
 
+
class ConcatPooling(nn.Module):
    def __init__(self):
        super().__init__()
 
-    def forward(
-
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        return x.flatten(start_dim=1, end_dim=2)
+
 
class AveragePooling(nn.Module):
    def __init__(self):
        super().__init__()
 
-    def forward(
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
        if mask is None:
            return torch.mean(x, dim=1)
        else:
@@ -290,24 +345,29 @@
            non_padding_length = mask.sum(dim=-1)
            return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
 
+
class SumPooling(nn.Module):
    def __init__(self):
        super().__init__()
 
-    def forward(
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
        if mask is None:
            return torch.sum(x, dim=1)
        else:
            return torch.bmm(mask, x).squeeze(1)
 
+
class MLP(nn.Module):
    def __init__(
-
-
-
-
-
-
+        self,
+        input_dim: int,
+        output_layer: bool = True,
+        dims: list[int] | None = None,
+        dropout: float = 0.0,
+        activation: str = "relu",
+    ):
        super().__init__()
        if dims is None:
            dims = []
@@ -320,29 +380,32 @@
            layers.append(activation_layer(activation))
            layers.append(nn.Dropout(p=dropout))
            current_dim = i_dim
-
+
        if output_layer:
            layers.append(nn.Linear(current_dim, 1))
            self.output_dim = 1
        else:
-            self.output_dim = current_dim
+            self.output_dim = current_dim
        self.mlp = nn.Sequential(*layers)
+
    def forward(self, x):
        return self.mlp(x)
-
+
+
class FM(nn.Module):
    def __init__(self, reduce_sum: bool = True):
        super().__init__()
        self.reduce_sum = reduce_sum
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        square_of_sum = torch.sum(x, dim=1)**2
+        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x**2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix
 
+
class CrossLayer(nn.Module):
    def __init__(self, input_dim: int):
        super(CrossLayer, self).__init__()
@@ -353,60 +416,74 @@ class CrossLayer(nn.Module):
        x = self.w(x_i) * x_0 + self.b
        return x
 
+
class SENETLayer(nn.Module):
-    def __init__(
-        self,
-        num_fields: int,
-        reduction_ratio: int = 3):
+    def __init__(self, num_fields: int, reduction_ratio: int = 3):
        super(SENETLayer, self).__init__()
-        reduced_size = max(1, int(num_fields/ reduction_ratio))
-        self.mlp = nn.Sequential(
-
-
-
+        reduced_size = max(1, int(num_fields / reduction_ratio))
+        self.mlp = nn.Sequential(
+            nn.Linear(num_fields, reduced_size, bias=False),
+            nn.ReLU(),
+            nn.Linear(reduced_size, num_fields, bias=False),
+            nn.ReLU(),
+        )
+
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = torch.mean(x, dim=-1, out=None)
        a = self.mlp(z)
-        v = x*a.unsqueeze(-1)
+        v = x * a.unsqueeze(-1)
        return v
 
+
class BiLinearInteractionLayer(nn.Module):
    def __init__(
-
-
-        num_fields: int,
-        bilinear_type: str = "field_interaction"):
+        self, input_dim: int, num_fields: int, bilinear_type: str = "field_interaction"
+    ):
        super(BiLinearInteractionLayer, self).__init__()
        self.bilinear_type = bilinear_type
        if self.bilinear_type == "field_all":
            self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
        elif self.bilinear_type == "field_each":
-            self.bilinear_layer = nn.ModuleList(
+            self.bilinear_layer = nn.ModuleList(
+                [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
+            )
        elif self.bilinear_type == "field_interaction":
-            self.bilinear_layer = nn.ModuleList(
+            self.bilinear_layer = nn.ModuleList(
+                [
+                    nn.Linear(input_dim, input_dim, bias=False)
+                    for i, j in combinations(range(num_fields), 2)
+                ]
+            )
        else:
            raise NotImplementedError()
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feature_emb = torch.split(x, 1, dim=1)
        if self.bilinear_type == "field_all":
-            bilinear_list = [
+            bilinear_list = [
+                self.bilinear_layer(v_i) * v_j
+                for v_i, v_j in combinations(feature_emb, 2)
+            ]
        elif self.bilinear_type == "field_each":
-            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)] # type: ignore[assignment]
+            bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]  # type: ignore[assignment]
        elif self.bilinear_type == "field_interaction":
-            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
+            bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]  # type: ignore[assignment]
        return torch.cat(bilinear_list, dim=1)
 
+
class MultiHeadSelfAttention(nn.Module):
    def __init__(
-
-
-
-
-
+        self,
+        embedding_dim: int,
+        num_heads: int = 2,
+        dropout: float = 0.0,
+        use_residual: bool = True,
+    ):
        super().__init__()
        if embedding_dim % num_heads != 0:
-            raise ValueError(
+            raise ValueError(
+                f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+            )
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
@@ -417,24 +494,34 @@ class MultiHeadSelfAttention(nn.Module):
        if self.use_residual:
            self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
-
+
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, num_fields, _ = x.shape
        Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
        K = self.W_K(x)
        V = self.W_V(x)
        # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
-        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-
-
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
        # Attention scores
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
-        attention_output = torch.matmul(
+        attention_output = torch.matmul(
+            attention_weights, V
+        )  # [batch_size, num_heads, num_fields, head_dim]
        # Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous()
-        attention_output = attention_output.view(
+        attention_output = attention_output.view(
+            batch_size, num_fields, self.embedding_dim
+        )
        # Residual connection
        if self.use_residual:
            output = attention_output + self.W_Res(x)
@@ -443,17 +530,20 @@
        output = F.relu(output)
        return output
 
+
class AttentionPoolingLayer(nn.Module):
    """
    Attention pooling layer for DIN/DIEN
    Computes attention weights between query (candidate item) and keys (user behavior sequence)
    """
+
    def __init__(
-
-
-
-
-
+        self,
+        embedding_dim: int,
+        hidden_units: list = [80, 40],
+        activation: str = "sigmoid",
+        use_softmax: bool = True,
+    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_softmax = use_softmax
@@ -467,8 +557,14 @@
            input_dim = hidden_unit
        layers.append(nn.Linear(input_dim, 1))
        self.attention_net = nn.Sequential(*layers)
-
-    def forward(
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        keys: torch.Tensor,
+        keys_length: torch.Tensor | None = None,
+        mask: torch.Tensor | None = None,
+    ):
        """
        Args:
            query: [batch_size, embedding_dim] - candidate item embedding
@@ -479,28 +575,46 @@
            output: [batch_size, embedding_dim] - attention pooled representation
        """
        batch_size, sequence_length, embedding_dim = keys.shape
-        assert query.shape == (
+        assert query.shape == (
+            batch_size,
+            embedding_dim,
+        ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
        if mask is None and keys_length is not None:
            # keys_length: (batch_size,)
            device = keys.device
-            seq_range = torch.arange(sequence_length, device=device).unsqueeze(
+            seq_range = torch.arange(sequence_length, device=device).unsqueeze(
+                0
+            )  # (1, sequence_length)
            mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
        if mask is not None:
            if mask.dim() == 2:
                # (B, L)
                mask = mask.unsqueeze(-1)
-            elif
+            elif (
+                mask.dim() == 3
+                and mask.shape[1] == 1
+                and mask.shape[2] == sequence_length
+            ):
                # (B, 1, L) -> (B, L, 1)
                mask = mask.transpose(1, 2)
-            elif
+            elif (
+                mask.dim() == 3
+                and mask.shape[1] == sequence_length
+                and mask.shape[2] == 1
+            ):
                pass
            else:
-                raise ValueError(
+                raise ValueError(
+                    f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
+                )
            mask = mask.to(keys.dtype)
        # Expand query to (B, L, D)
        query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
        # [query, key, query-key, query*key] -> (B, L, 4D)
-        attention_input = torch.cat(
+        attention_input = torch.cat(
+            [query_expanded, keys, query_expanded - keys, query_expanded * keys],
+            dim=-1,
+        )
        attention_scores = self.attention_net(attention_input)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)