nextrec-0.4.1-py3-none-any.whl → nextrec-0.4.3-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py
CHANGED
@@ -5,19 +5,21 @@ Date: create on 27/10/2025
 Checkpoint: edit on 29/11/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
-from __future__ import annotations

-from
-from collections import OrderedDict
+from __future__ import annotations

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from itertools import combinations
+from collections import OrderedDict
+
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.utils.initializer import get_initializer
 from nextrec.basic.activation import activation_layer

+
 class PredictionLayer(nn.Module):
     def __init__(
         self,
@@ -30,7 +32,7 @@ class PredictionLayer(nn.Module):
         self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
         if len(self.task_types) == 0:
             raise ValueError("At least one task_type must be specified.")
-
+
         if task_dims is None:
             dims = [1] * len(self.task_types)
         elif isinstance(task_dims, int):
@@ -38,7 +40,9 @@ class PredictionLayer(nn.Module):
         else:
             dims = list(task_dims)
         if len(dims) not in (1, len(self.task_types)):
-            raise ValueError(
+            raise ValueError(
+                "[PredictionLayer Error]: task_dims must be None, a single int (shared), or a sequence of the same length as task_type."
+            )
         if len(dims) == 1 and len(self.task_types) > 1:
             dims = dims * len(self.task_types)
         self.task_dims = dims
@@ -47,11 +51,11 @@ class PredictionLayer(nn.Module):

         # slice offsets per task
         start = 0
-        self.
+        self.task_slices: list[tuple[int, int]] = []
         for dim in self.task_dims:
             if dim < 1:
                 raise ValueError("Each task dimension must be >= 1.")
-            self.
+            self.task_slices.append((start, start + dim))
             start += dim
         if use_bias:
             self.bias = nn.Parameter(torch.zeros(self.total_dim))
@@ -62,27 +66,33 @@ class PredictionLayer(nn.Module):
         if x.dim() == 1:
             x = x.unsqueeze(0)  # (1 * total_dim)
         if x.shape[-1] != self.total_dim:
-            raise ValueError(
+            raise ValueError(
+                f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+            )
         logits = x if self.bias is None else x + self.bias
         outputs = []
-        for task_type, (start, end) in zip(self.task_types, self.
-            task_logits = logits[..., start:end]
+        for task_type, (start, end) in zip(self.task_types, self.task_slices):
+            task_logits = logits[..., start:end]  # logits for the current task
             if self.return_logits:
                 outputs.append(task_logits)
                 continue
             task = task_type.lower()
-            if task ==
-
-            elif task ==
-
-            elif task ==
-
+            if task == "binary":
+                outputs.append(torch.sigmoid(task_logits))
+            elif task == "regression":
+                outputs.append(task_logits)
+            elif task == "multiclass":
+                outputs.append(torch.softmax(task_logits, dim=-1))
             else:
-                raise ValueError(
-
-
+                raise ValueError(
+                    f"[PredictionLayer Error]: Unsupported task_type '{task_type}'."
+                )
+        result = torch.cat(
+            outputs, dim=-1
+        )  # single: (N,1), multi-task/multi-class: (N,total_dim)
         return result

+
 class EmbeddingLayer(nn.Module):
     def __init__(self, features: list):
         super().__init__()
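For orientation, a minimal standalone sketch of the per-task slicing and activation that the rewritten forward above performs (plain PyTorch; the task list and dims here are illustrative, not taken from the package):

    import torch

    # Two illustrative heads: one binary (dim 1) and one multiclass (dim 3) -> total_dim = 4.
    task_types = ["binary", "multiclass"]
    task_slices = [(0, 1), (1, 4)]
    logits = torch.randn(2, 4)  # (batch, total_dim)

    outputs = []
    for task, (start, end) in zip(task_types, task_slices):
        task_logits = logits[..., start:end]
        if task == "binary":
            outputs.append(torch.sigmoid(task_logits))
        elif task == "multiclass":
            outputs.append(torch.softmax(task_logits, dim=-1))
        else:  # "regression" passes logits through unchanged
            outputs.append(task_logits)

    result = torch.cat(outputs, dim=-1)  # (2, 4): a sigmoid probability followed by a softmax over 3 classes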
@@ -96,20 +106,30 @@ class EmbeddingLayer(nn.Module):
                 if feature.embedding_name in self.embed_dict:
                     continue
                 if getattr(feature, "pretrained_weight", None) is not None:
-                    weight = feature.pretrained_weight
-                    if weight.shape != (feature.vocab_size, feature.embedding_dim):
-                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).")
-                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx)
-                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained
+                    weight = feature.pretrained_weight  # type: ignore[assignment]
+                    if weight.shape != (feature.vocab_size, feature.embedding_dim):  # type: ignore[assignment]
+                        raise ValueError(f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim}).")  # type: ignore[assignment]
+                    embedding = nn.Embedding.from_pretrained(embeddings=weight, freeze=feature.freeze_pretrained, padding_idx=feature.padding_idx)  # type: ignore[assignment]
+                    embedding.weight.requires_grad = feature.trainable and not feature.freeze_pretrained  # type: ignore[assignment]
                 else:
-                    embedding = nn.Embedding(
+                    embedding = nn.Embedding(
+                        num_embeddings=feature.vocab_size,
+                        embedding_dim=feature.embedding_dim,
+                        padding_idx=feature.padding_idx,
+                    )
                     embedding.weight.requires_grad = feature.trainable
-                    initialization = get_initializer(
+                    initialization = get_initializer(
+                        init_type=feature.init_type,
+                        activation="linear",
+                        param=feature.init_params,
+                    )
                     initialization(embedding.weight)
                 self.embed_dict[feature.embedding_name] = embedding
             elif isinstance(feature, DenseFeature):
                 if not feature.use_embedding:
-                    self.dense_input_dims[feature.name] = max(
+                    self.dense_input_dims[feature.name] = max(
+                        int(getattr(feature, "input_dim", 1)), 1
+                    )
                     continue
                 if feature.name in self.dense_transforms:
                     continue
@@ -121,7 +141,9 @@ class EmbeddingLayer(nn.Module):
                 self.dense_transforms[feature.name] = dense_linear
                 self.dense_input_dims[feature.name] = in_dim
             else:
-                raise TypeError(
+                raise TypeError(
+                    f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}"
+                )
         self.output_dim = self.compute_output_dim()

     def forward(
@@ -153,7 +175,9 @@ class EmbeddingLayer(nn.Module):
                 elif feature.combiner == "concat":
                     pooling_layer = ConcatPooling()
                 else:
-                    raise ValueError(
+                    raise ValueError(
+                        f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
+                    )
                 feature_mask = InputMask()(x, feature, seq_input)
                 sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))

@@ -168,9 +192,11 @@ class EmbeddingLayer(nn.Module):
             if dense_embeds:
                 pieces.append(torch.cat(dense_embeds, dim=1))
             if not pieces:
-                raise ValueError(
+                raise ValueError(
+                    "[EmbeddingLayer Error]: No input features found for EmbeddingLayer."
+                )
             return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
-
+
         # squeeze_dim=False requires embeddings with identical last dimension
         output_embeddings = list(sparse_embeds)
         if dense_embeds:
@@ -178,36 +204,53 @@ class EmbeddingLayer(nn.Module):
                 target_dim = output_embeddings[0].shape[-1]
                 for emb in dense_embeds:
                     if emb.shape[-1] != target_dim:
-                        raise ValueError(
+                        raise ValueError(
+                            f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense feature dimensions to match the embedding dimension of sparse/sequence features ({target_dim}), but got {emb.shape[-1]}."
+                        )
                 output_embeddings.extend(emb.unsqueeze(1) for emb in dense_embeds)
             else:
                 dims = {emb.shape[-1] for emb in dense_embeds}
                 if len(dims) != 1:
-                    raise ValueError(
+                    raise ValueError(
+                        f"[EmbeddingLayer Error]: squeeze_dim=False requires all dense features to have identical dimensions when no sparse/sequence features are present, but got dimensions {dims}."
+                    )
                 output_embeddings = [emb.unsqueeze(1) for emb in dense_embeds]
         if not output_embeddings:
-            raise ValueError(
+            raise ValueError(
+                "[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions."
+            )
         return torch.cat(output_embeddings, dim=1)

-    def project_dense(
+    def project_dense(
+        self, feature: DenseFeature, x: dict[str, torch.Tensor]
+    ) -> torch.Tensor:
         if feature.name not in x:
-            raise KeyError(
+            raise KeyError(
+                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input."
+            )
         value = x[feature.name].float()
         if value.dim() == 1:
             value = value.unsqueeze(-1)
         else:
             value = value.view(value.size(0), -1)
-        expected_in_dim = self.dense_input_dims.get(
+        expected_in_dim = self.dense_input_dims.get(
+            feature.name, max(int(getattr(feature, "input_dim", 1)), 1)
+        )
         if value.shape[1] != expected_in_dim:
-            raise ValueError(
+            raise ValueError(
+                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {expected_in_dim} inputs but got {value.shape[1]}."
+            )
         if not feature.use_embedding:
-            return value
+            return value
         dense_layer = self.dense_transforms[feature.name]
         return dense_layer(value)

-    def compute_output_dim(
+    def compute_output_dim(
+        self,
+        features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None,
+    ) -> int:
         candidates = list(features) if features is not None else self.features
-        unique_feats = OrderedDict((feat.name, feat) for feat in candidates)
+        unique_feats = OrderedDict((feat.name, feat) for feat in candidates)  # type: ignore[assignment]
         dim = 0
         for feat in unique_feats.values():
             if isinstance(feat, DenseFeature):
@@ -218,28 +261,34 @@ class EmbeddingLayer(nn.Module):
             elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
                 dim += feat.embedding_dim * feat.max_len
             else:
-                dim += feat.embedding_dim
+                dim += feat.embedding_dim  # type: ignore[assignment]
         return dim

     def get_input_dim(self, features: list[object] | None = None) -> int:
-        return self.compute_output_dim(features)
+        return self.compute_output_dim(features)  # type: ignore[assignment]

     @property
     def input_dim(self) -> int:
         return self.output_dim

+
 class InputMask(nn.Module):
     def __init__(self):
         super().__init__()

-    def forward(
+    def forward(
+        self,
+        x: dict[str, torch.Tensor],
+        feature: SequenceFeature,
+        seq_tensor: torch.Tensor | None = None,
+    ):
         if seq_tensor is not None:
             values = seq_tensor
         else:
             values = x[feature.name]
         values = values.long()
         padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
-        mask =
+        mask = values != padding_idx

         if mask.dim() == 1:
             # [B] -> [B, 1, 1]
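The width returned by compute_output_dim above is additive per unique feature: a sequence feature pooled with combiner="concat" contributes embedding_dim * max_len, while other embedded features contribute embedding_dim. A small standalone arithmetic sketch (the feature sizes below are made up for illustration):

    # Hypothetical feature set: one sparse id embedded to 16,
    # one mean-pooled sequence embedded to 16,
    # one concat-pooled sequence with embedding_dim=16 and max_len=20.
    sparse_dim = 16
    mean_pooled_dim = 16
    concat_pooled_dim = 16 * 20          # combiner="concat" -> embedding_dim * max_len
    total_dim = sparse_dim + mean_pooled_dim + concat_pooled_dim
    print(total_dim)                     # 352: the flattened width a downstream MLP would receive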
@@ -253,14 +302,14 @@ class InputMask(nn.Module):
             if mask.size(1) != 1 and mask.size(2) == 1:
                 mask = mask.squeeze(-1).unsqueeze(1)
             else:
-                raise ValueError(
+                raise ValueError(
+                    f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
+                )
         return mask.float()

+
 class LR(nn.Module):
-    def __init__(
-        self,
-        input_dim: int,
-        sigmoid: bool = False):
+    def __init__(self, input_dim: int, sigmoid: bool = False):
         super().__init__()
         self.sigmoid = sigmoid
         self.fc = nn.Linear(input_dim, 1, bias=True)
@@ -271,18 +320,24 @@ class LR(nn.Module):
         else:
             return self.fc(x)

+
 class ConcatPooling(nn.Module):
     def __init__(self):
         super().__init__()

-    def forward(
-
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        return x.flatten(start_dim=1, end_dim=2)
+

 class AveragePooling(nn.Module):
     def __init__(self):
         super().__init__()

-    def forward(
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
         if mask is None:
             return torch.mean(x, dim=1)
         else:
@@ -290,59 +345,76 @@ class AveragePooling(nn.Module):
             non_padding_length = mask.sum(dim=-1)
             return sum_pooling_matrix / (non_padding_length.float() + 1e-16)

+
 class SumPooling(nn.Module):
     def __init__(self):
         super().__init__()

-    def forward(
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
         if mask is None:
             return torch.sum(x, dim=1)
         else:
             return torch.bmm(mask, x).squeeze(1)

+
 class MLP(nn.Module):
     def __init__(
-
-
-
-
-
-
+        self,
+        input_dim: int,
+        output_layer: bool = True,
+        dims: list[int] | None = None,
+        dropout: float = 0.0,
+        activation: str = "relu",
+        use_norm: bool = True,
+        norm_type: str = "layer_norm",
+    ):
         super().__init__()
         if dims is None:
             dims = []
         layers = []
         current_dim = input_dim
-
         for i_dim in dims:
             layers.append(nn.Linear(current_dim, i_dim))
-
+            if use_norm:
+                if norm_type == "batch_norm":
+                    # **IMPORTANT** be careful when using BatchNorm1d in distributed training, nextrec does not support sync batch norm now
+                    layers.append(nn.BatchNorm1d(i_dim))
+                elif norm_type == "layer_norm":
+                    layers.append(nn.LayerNorm(i_dim))
+                else:
+                    raise ValueError(f"Unsupported norm_type: {norm_type}")
+
             layers.append(activation_layer(activation))
             layers.append(nn.Dropout(p=dropout))
             current_dim = i_dim
-
+        # output layer
         if output_layer:
             layers.append(nn.Linear(current_dim, 1))
             self.output_dim = 1
         else:
-            self.output_dim = current_dim
+            self.output_dim = current_dim
         self.mlp = nn.Sequential(*layers)
+
     def forward(self, x):
         return self.mlp(x)
-
+
+
 class FM(nn.Module):
     def __init__(self, reduce_sum: bool = True):
         super().__init__()
         self.reduce_sum = reduce_sum

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        square_of_sum = torch.sum(x, dim=1)**2
+        square_of_sum = torch.sum(x, dim=1) ** 2
         sum_of_square = torch.sum(x**2, dim=1)
         ix = square_of_sum - sum_of_square
         if self.reduce_sum:
             ix = torch.sum(ix, dim=1, keepdim=True)
         return 0.5 * ix

+
 class CrossLayer(nn.Module):
     def __init__(self, input_dim: int):
         super(CrossLayer, self).__init__()
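The MLP signature above now exposes use_norm / norm_type, inserting a BatchNorm1d or LayerNorm between each Linear and its activation. A usage sketch, assuming MLP can be imported from nextrec.basic.layers as this file path suggests (the argument values are illustrative only):

    import torch
    from nextrec.basic.layers import MLP  # assumption: importable from this module in 0.4.3

    mlp = MLP(
        input_dim=64,
        output_layer=False,        # keep the last hidden width instead of the 1-unit head
        dims=[128, 64],
        dropout=0.1,
        activation="relu",
        use_norm=True,
        norm_type="layer_norm",    # "batch_norm" is also accepted; note the distributed-training caveat above
    )
    out = mlp(torch.randn(8, 64))  # -> shape (8, 64); mlp.output_dim == 64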
@@ -353,60 +425,89 @@ class CrossLayer(nn.Module):
         x = self.w(x_i) * x_0 + self.b
         return x

+
 class SENETLayer(nn.Module):
-    def __init__(
-        self,
-        num_fields: int,
-        reduction_ratio: int = 3):
+    def __init__(self, num_fields: int, reduction_ratio: int = 3):
         super(SENETLayer, self).__init__()
-        reduced_size = max(1, int(num_fields/ reduction_ratio))
-        self.mlp = nn.Sequential(
-
-
-
+        reduced_size = max(1, int(num_fields / reduction_ratio))
+        self.mlp = nn.Sequential(
+            nn.Linear(num_fields, reduced_size, bias=False),
+            nn.ReLU(),
+            nn.Linear(reduced_size, num_fields, bias=False),
+            nn.ReLU(),
+        )
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         z = torch.mean(x, dim=-1, out=None)
         a = self.mlp(z)
-        v = x*a.unsqueeze(-1)
+        v = x * a.unsqueeze(-1)
         return v

+
 class BiLinearInteractionLayer(nn.Module):
     def __init__(
-
-
-        num_fields: int,
-        bilinear_type: str = "field_interaction"):
+        self, input_dim: int, num_fields: int, bilinear_type: str = "field_interaction"
+    ):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
         if self.bilinear_type == "field_all":
             self.bilinear_layer = nn.Linear(input_dim, input_dim, bias=False)
         elif self.bilinear_type == "field_each":
-            self.bilinear_layer = nn.ModuleList(
+            self.bilinear_layer = nn.ModuleList(
+                [nn.Linear(input_dim, input_dim, bias=False) for i in range(num_fields)]
+            )
         elif self.bilinear_type == "field_interaction":
-            self.bilinear_layer = nn.ModuleList(
+            self.bilinear_layer = nn.ModuleList(
+                [
+                    nn.Linear(input_dim, input_dim, bias=False)
+                    for i, j in combinations(range(num_fields), 2)
+                ]
+            )
         else:
             raise NotImplementedError()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         feature_emb = torch.split(x, 1, dim=1)
         if self.bilinear_type == "field_all":
-            bilinear_list = [
+            bilinear_list = [
+                self.bilinear_layer(v_i) * v_j
+                for v_i, v_j in combinations(feature_emb, 2)
+            ]
         elif self.bilinear_type == "field_each":
-            bilinear_list = [self.bilinear_layer[i](feature_emb[i])*feature_emb[j] for i,j in combinations(range(len(feature_emb)), 2)] # type: ignore[assignment]
+            bilinear_list = [self.bilinear_layer[i](feature_emb[i]) * feature_emb[j] for i, j in combinations(range(len(feature_emb)), 2)]  # type: ignore[assignment]
         elif self.bilinear_type == "field_interaction":
-            bilinear_list = [self.bilinear_layer[i](v[0])*v[1] for i,v in enumerate(combinations(feature_emb, 2))]
+            bilinear_list = [self.bilinear_layer[i](v[0]) * v[1] for i, v in enumerate(combinations(feature_emb, 2))]  # type: ignore[assignment]
         return torch.cat(bilinear_list, dim=1)

+
+class HadamardInteractionLayer(nn.Module):
+    """Hadamard interaction layer for Deep-FiBiNET (0 case in 01/11)."""
+
+    def __init__(self, num_fields: int):
+        super().__init__()
+        self.num_fields = num_fields
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: [B, F, D]
+        feature_emb = torch.split(x, 1, dim=1)  # list of F tensors [B,1,D]
+
+        hadamard_list = [v_i * v_j for (v_i, v_j) in combinations(feature_emb, 2)]
+        return torch.cat(hadamard_list, dim=1)  # [B, num_pairs, D]
+
+
 class MultiHeadSelfAttention(nn.Module):
     def __init__(
-
-
-
-
-
+        self,
+        embedding_dim: int,
+        num_heads: int = 2,
+        dropout: float = 0.0,
+        use_residual: bool = True,
+    ):
         super().__init__()
         if embedding_dim % num_heads != 0:
-            raise ValueError(
+            raise ValueError(
+                f"[MultiHeadSelfAttention Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+            )
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
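Both BiLinearInteractionLayer and the new HadamardInteractionLayer emit one interaction vector per unordered field pair, i.e. F*(F-1)/2 pairs for F fields. A standalone shape check mirroring that pairing (sizes are illustrative, not taken from the package):

    import torch
    from itertools import combinations

    B, num_fields, D = 4, 6, 8               # batch, fields, embedding dim (made up)
    x = torch.randn(B, num_fields, D)
    fields = torch.split(x, 1, dim=1)         # num_fields tensors of shape (B, 1, D)
    pairs = list(combinations(fields, 2))     # 6*5//2 = 15 unordered pairs
    hadamard = torch.cat([v_i * v_j for v_i, v_j in pairs], dim=1)
    print(len(pairs), hadamard.shape)         # 15 torch.Size([4, 15, 8])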
@@ -417,24 +518,34 @@ class MultiHeadSelfAttention(nn.Module):
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.dropout = nn.Dropout(dropout)
-
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         batch_size, num_fields, _ = x.shape
         Q = self.W_Q(x)  # [batch_size, num_fields, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
         # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
-        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
-
-
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )
         # Attention scores
-        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim
+        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
         attention_weights = F.softmax(scores, dim=-1)
         attention_weights = self.dropout(attention_weights)
-        attention_output = torch.matmul(
+        attention_output = torch.matmul(
+            attention_weights, V
+        )  # [batch_size, num_heads, num_fields, head_dim]
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
-        attention_output = attention_output.view(
+        attention_output = attention_output.view(
+            batch_size, num_fields, self.embedding_dim
+        )
         # Residual connection
         if self.use_residual:
             output = attention_output + self.W_Res(x)
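A usage sketch for the multi-head self-attention block above, assuming it can be imported from nextrec.basic.layers; the shapes follow the comments in the hunk and the sizes are illustrative:

    import torch
    from nextrec.basic.layers import MultiHeadSelfAttention  # assumption: importable in 0.4.3

    attn = MultiHeadSelfAttention(embedding_dim=32, num_heads=4, dropout=0.1, use_residual=True)
    field_emb = torch.randn(16, 10, 32)  # (batch_size, num_fields, embedding_dim)
    out = attn(field_emb)                # same shape (16, 10, 32); embedding_dim must be divisible by num_heads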
@@ -443,17 +554,20 @@ class MultiHeadSelfAttention(nn.Module):
             output = F.relu(output)
         return output

+
 class AttentionPoolingLayer(nn.Module):
     """
     Attention pooling layer for DIN/DIEN
     Computes attention weights between query (candidate item) and keys (user behavior sequence)
     """
+
     def __init__(
-
-
-
-
-
+        self,
+        embedding_dim: int,
+        hidden_units: list = [80, 40],
+        activation: str = "sigmoid",
+        use_softmax: bool = False,
+    ):
         super().__init__()
         self.embedding_dim = embedding_dim
         self.use_softmax = use_softmax
@@ -463,12 +577,18 @@ class AttentionPoolingLayer(nn.Module):
         layers = []
         for hidden_unit in hidden_units:
             layers.append(nn.Linear(input_dim, hidden_unit))
-            layers.append(activation_layer(activation))
+            layers.append(activation_layer(activation, emb_size=hidden_unit))
             input_dim = hidden_unit
         layers.append(nn.Linear(input_dim, 1))
         self.attention_net = nn.Sequential(*layers)
-
-    def forward(
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        keys: torch.Tensor,
+        keys_length: torch.Tensor | None = None,
+        mask: torch.Tensor | None = None,
+    ):
         """
         Args:
             query: [batch_size, embedding_dim] - candidate item embedding
@@ -479,28 +599,46 @@ class AttentionPoolingLayer(nn.Module):
             output: [batch_size, embedding_dim] - attention pooled representation
         """
         batch_size, sequence_length, embedding_dim = keys.shape
-        assert query.shape == (
+        assert query.shape == (
+            batch_size,
+            embedding_dim,
+        ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
         if mask is None and keys_length is not None:
             # keys_length: (batch_size,)
             device = keys.device
-            seq_range = torch.arange(sequence_length, device=device).unsqueeze(
+            seq_range = torch.arange(sequence_length, device=device).unsqueeze(
+                0
+            )  # (1, sequence_length)
             mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
         if mask is not None:
             if mask.dim() == 2:
                 # (B, L)
                 mask = mask.unsqueeze(-1)
-            elif
+            elif (
+                mask.dim() == 3
+                and mask.shape[1] == 1
+                and mask.shape[2] == sequence_length
+            ):
                 # (B, 1, L) -> (B, L, 1)
                 mask = mask.transpose(1, 2)
-            elif
+            elif (
+                mask.dim() == 3
+                and mask.shape[1] == sequence_length
+                and mask.shape[2] == 1
+            ):
                 pass
             else:
-                raise ValueError(
+                raise ValueError(
+                    f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
+                )
             mask = mask.to(keys.dtype)
         # Expand query to (B, L, D)
         query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
         # [query, key, query-key, query*key] -> (B, L, 4D)
-        attention_input = torch.cat(
+        attention_input = torch.cat(
+            [query_expanded, keys, query_expanded - keys, query_expanded * keys],
+            dim=-1,
+        )
         attention_scores = self.attention_net(attention_input)
         if mask is not None:
             attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
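Finally, a usage sketch for AttentionPoolingLayer with a keys_length mask, assuming it can be imported from nextrec.basic.layers; the forward signature and the (batch_size, embedding_dim) output follow the docstring in the hunk above, and the sizes are illustrative:

    import torch
    from nextrec.basic.layers import AttentionPoolingLayer  # assumption: importable in 0.4.3

    pool = AttentionPoolingLayer(embedding_dim=16, hidden_units=[80, 40], activation="sigmoid")
    query = torch.randn(4, 16)                  # candidate item embedding (B, D)
    keys = torch.randn(4, 20, 16)               # user behaviour sequence (B, L, D)
    keys_length = torch.tensor([20, 7, 13, 1])  # valid lengths; padded steps are masked with -1e9
    pooled = pool(query, keys, keys_length=keys_length)  # -> (4, 16) per the docstring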