nextrec 0.4.20-py3-none-any.whl → 0.4.22-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -4
- nextrec/basic/callback.py +39 -87
- nextrec/basic/features.py +149 -28
- nextrec/basic/heads.py +3 -1
- nextrec/basic/layers.py +375 -94
- nextrec/basic/loggers.py +236 -39
- nextrec/basic/model.py +259 -326
- nextrec/basic/session.py +2 -2
- nextrec/basic/summary.py +323 -0
- nextrec/cli.py +3 -3
- nextrec/data/data_processing.py +45 -1
- nextrec/data/dataloader.py +2 -2
- nextrec/data/preprocessor.py +2 -2
- nextrec/loss/__init__.py +0 -4
- nextrec/loss/grad_norm.py +3 -3
- nextrec/models/multi_task/esmm.py +4 -6
- nextrec/models/multi_task/mmoe.py +4 -6
- nextrec/models/multi_task/ple.py +6 -8
- nextrec/models/multi_task/poso.py +5 -7
- nextrec/models/multi_task/share_bottom.py +6 -8
- nextrec/models/ranking/afm.py +4 -6
- nextrec/models/ranking/autoint.py +4 -6
- nextrec/models/ranking/dcn.py +8 -7
- nextrec/models/ranking/dcn_v2.py +4 -6
- nextrec/models/ranking/deepfm.py +5 -7
- nextrec/models/ranking/dien.py +8 -7
- nextrec/models/ranking/din.py +8 -7
- nextrec/models/ranking/eulernet.py +5 -7
- nextrec/models/ranking/ffm.py +5 -7
- nextrec/models/ranking/fibinet.py +4 -6
- nextrec/models/ranking/fm.py +4 -6
- nextrec/models/ranking/lr.py +4 -6
- nextrec/models/ranking/masknet.py +8 -9
- nextrec/models/ranking/pnn.py +4 -6
- nextrec/models/ranking/widedeep.py +5 -7
- nextrec/models/ranking/xdeepfm.py +8 -7
- nextrec/models/retrieval/dssm.py +4 -10
- nextrec/models/retrieval/dssm_v2.py +0 -6
- nextrec/models/retrieval/mind.py +4 -10
- nextrec/models/retrieval/sdm.py +4 -10
- nextrec/models/retrieval/youtube_dnn.py +4 -10
- nextrec/models/sequential/hstu.py +1 -3
- nextrec/utils/__init__.py +17 -15
- nextrec/utils/config.py +15 -5
- nextrec/utils/console.py +2 -2
- nextrec/utils/feature.py +2 -2
- nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
- nextrec/utils/torch_utils.py +57 -112
- nextrec/utils/types.py +63 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
- nextrec-0.4.22.dist-info/RECORD +81 -0
- nextrec-0.4.20.dist-info/RECORD +0 -79
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/layers.py
CHANGED
@@ -1,8 +1,8 @@
 """
-Layer implementations used across NextRec
+Layer implementations used across NextRec.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 27/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -10,7 +10,9 @@ from __future__ import annotations
 
 from collections import OrderedDict
 from itertools import combinations
+from typing import Literal
 
+import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -23,7 +25,9 @@ from nextrec.utils.torch_utils import get_initializer
 class PredictionLayer(nn.Module):
     def __init__(
         self,
-        task_type:
+        task_type: (
+            Literal["binary", "regression"] | list[Literal["binary", "regression"]]
+        ) = "binary",
         task_dims: int | list[int] | None = None,
         use_bias: bool = True,
         return_logits: bool = False,
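A quick construction sketch for the tightened PredictionLayer signature above. Only what this hunk shows is assumed; the forward pass is not part of the diff.

from nextrec.basic.layers import PredictionLayer

# Two tasks sharing a single output dim: task_dims may be None, one int, or a list.
head = PredictionLayer(task_type=["binary", "regression"], task_dims=1)

# A task_dims list whose length is neither 1 nor len(task_type) now raises the
# clearer error message completed in this hunk.
try:
    PredictionLayer(task_type=["binary", "regression"], task_dims=[1, 1, 1])
except ValueError as err:
    print(err)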
@@ -51,7 +55,8 @@ class PredictionLayer(nn.Module):
             dims = list(task_dims)
             if len(dims) not in (1, len(self.task_types)):
                 raise ValueError(
-                    "[PredictionLayer Error]: task_dims must be None, a single int (shared),
+                    "[PredictionLayer Error]: task_dims must be None, a single int (shared), "
+                    "or a sequence of the same length as task_type."
                 )
             if len(dims) == 1 and len(self.task_types) > 1:
                 dims = dims * len(self.task_types)
@@ -61,7 +66,7 @@ class PredictionLayer(nn.Module):
 
         # slice offsets per task
         start = 0
-        self.task_slices
+        self.task_slices = []
         for dim in self.task_dims:
             if dim < 1:
                 raise ValueError("Each task dimension must be >= 1.")
@@ -106,53 +111,96 @@ class EmbeddingLayer(nn.Module):
         super().__init__()
         self.features = list(features)
         self.embed_dict = nn.ModuleDict()
-        self.dense_transforms = nn.ModuleDict()
-        self.dense_input_dims
+        self.dense_transforms = nn.ModuleDict()  # dense feature projection layers
+        self.dense_input_dims = {}
+        self.sequence_poolings = nn.ModuleDict()
 
         for feature in self.features:
             if isinstance(feature, (SparseFeature, SequenceFeature)):
-                if feature.embedding_name in self.embed_dict:
-                    ... (old lines 115-135 not shown in the diff view)
+                if feature.embedding_name not in self.embed_dict:
+                    if feature.pretrained_weight is not None:
+                        weight = feature.pretrained_weight
+                        if weight.shape != (
+                            feature.vocab_size,
+                            feature.embedding_dim,
+                        ):
+                            raise ValueError(
+                                f"[EmbeddingLayer Error]: Pretrained weight for '{feature.embedding_name}' has shape {weight.shape}, expected ({feature.vocab_size}, {feature.embedding_dim})."
+                            )
+                        embedding = nn.Embedding.from_pretrained(
+                            embeddings=weight,
+                            freeze=feature.freeze_pretrained,
+                            padding_idx=feature.padding_idx,
+                        )
+                        embedding.weight.requires_grad = (
+                            feature.trainable and not feature.freeze_pretrained
+                        )
+                    else:
+                        embedding = nn.Embedding(
+                            num_embeddings=feature.vocab_size,
+                            embedding_dim=feature.embedding_dim,
+                            padding_idx=feature.padding_idx,
+                        )
+                        embedding.weight.requires_grad = feature.trainable
+                        initialization = get_initializer(
+                            init_type=feature.init_type,  # type: ignore[arg-type]
+                            activation="linear",
+                            param=feature.init_params,
+                        )
+                        initialization(embedding.weight)
+                    self.embed_dict[feature.embedding_name] = embedding
+
             elif isinstance(feature, DenseFeature):
-                if not feature.
-                    ... (old lines 138-140 not shown in the diff view)
-                    continue
+                if not feature.use_projection:
+                    input_dim = feature.input_dim
+                    self.dense_input_dims[feature.name] = max(int(input_dim), 1)
+                    continue  # skip if no projection is needed
                 if feature.name in self.dense_transforms:
-                    continue
-                    ... (old lines 144-146 not shown in the diff view)
+                    continue  # skip if already created
+
+                input_dim = feature.input_dim
+                out_dim = (
+                    feature.embedding_dim
+                    if feature.embedding_dim is not None
+                    else input_dim
+                )
+
+                dense_linear = nn.Linear(input_dim, out_dim, bias=True)
                 nn.init.xavier_uniform_(dense_linear.weight)
                 nn.init.zeros_(dense_linear.bias)
                 self.dense_transforms[feature.name] = dense_linear
-                self.dense_input_dims[feature.name] =
+                self.dense_input_dims[feature.name] = input_dim
             else:
                 raise TypeError(
                     f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}"
                 )
-
+            if isinstance(feature, SequenceFeature):
+                if feature.name in self.sequence_poolings:
+                    continue
+                if feature.combiner == "mean":
+                    pooling_layer = AveragePooling()
+                elif feature.combiner == "sum":
+                    pooling_layer = SumPooling()
+                elif feature.combiner == "concat":
+                    pooling_layer = ConcatPooling()
+                elif feature.combiner == "dot_attention":
+                    pooling_layer = DotProductAttentionPooling(feature.embedding_dim)
+                elif feature.combiner == "self_attention":
+                    if feature.embedding_dim % 4 != 0:
+                        raise ValueError(
+                            f"[EmbeddingLayer Error]: self_attention requires embedding_dim divisible by 4, got {feature.embedding_dim}."
+                        )
+                    pooling_layer = SelfAttentionPooling(
+                        feature.embedding_dim, num_heads=4, dropout=0.0
+                    )
+                else:
+                    raise ValueError(
+                        f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
+                    )
+                self.sequence_poolings[feature.name] = pooling_layer
+        self.output_dim = (
+            self.compute_output_dim()
+        )  # output dimension of the embedding layer
 
     def forward(
         self,
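For reference, the pretrained-embedding branch added above reduces to the standalone sketch below; the vocab size, dimensions, and flags are placeholders standing in for the corresponding feature attributes, not values taken from the package.

import torch
import torch.nn as nn

vocab_size, embedding_dim = 1000, 16
weight = torch.randn(vocab_size, embedding_dim)  # stands in for feature.pretrained_weight

# The layer validates weight.shape against (vocab_size, embedding_dim) before building the table.
embedding = nn.Embedding.from_pretrained(
    embeddings=weight,
    freeze=True,      # feature.freeze_pretrained
    padding_idx=0,    # feature.padding_idx
)
# Mirrors `feature.trainable and not feature.freeze_pretrained`.
embedding.weight.requires_grad = False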
@@ -160,8 +208,8 @@ class EmbeddingLayer(nn.Module):
         features: list[object],
         squeeze_dim: bool = False,
     ) -> torch.Tensor:
-        sparse_embeds
-        dense_embeds
+        sparse_embeds = []
+        dense_embeds = []
 
         for feature in features:
             if isinstance(feature, SparseFeature):
@@ -175,17 +223,7 @@ class EmbeddingLayer(nn.Module):
 
                 embed = self.embed_dict[feature.embedding_name]
                 seq_emb = embed(seq_input)  # [B, seq_len, emb_dim]
-
-                if feature.combiner == "mean":
-                    pooling_layer = AveragePooling()
-                elif feature.combiner == "sum":
-                    pooling_layer = SumPooling()
-                elif feature.combiner == "concat":
-                    pooling_layer = ConcatPooling()
-                else:
-                    raise ValueError(
-                        f"[EmbeddingLayer Error]: Unknown combiner for {feature.name}: {feature.combiner}"
-                    )
+                pooling_layer = self.sequence_poolings[feature.name]
                 feature_mask = InputMask()(x, feature, seq_input)
                 sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
 
@@ -238,17 +276,16 @@ class EmbeddingLayer(nn.Module):
             )
         value = x[feature.name].float()
         if value.dim() == 1:
-            value = value.unsqueeze(-1)
+            value = value.unsqueeze(-1)  # [B, 1]
         else:
-            value = value.view(value.size(0), -1)
-            ... (old lines 244-246 not shown in the diff view)
-        if value.shape[1] != expected_in_dim:
+            value = value.view(value.size(0), -1)  # [B, input_dim]
+        input_dim = feature.input_dim
+        assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
+        if value.shape[1] != assert_input_dim:
             raise ValueError(
-                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {
+                f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
             )
-        if not feature.
+        if not feature.use_projection:
             return value
         dense_layer = self.dense_transforms[feature.name]
         return dense_layer(value)
@@ -257,25 +294,25 @@ class EmbeddingLayer(nn.Module):
         self,
         features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None,
     ) -> int:
-        ... (old lines 260-261 not shown in the diff view)
+        """Compute the output dimension of the embedding layer."""
+        all_features = list(features) if features is not None else self.features
+        unique_feats = OrderedDict((feat.name, feat) for feat in all_features)
         dim = 0
         for feat in unique_feats.values():
             if isinstance(feat, DenseFeature):
-                ... (old lines 265-266 not shown in the diff view)
-                emb_dim = getattr(feat, "embedding_dim", None)
-                out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+                if feat.use_projection:
+                    out_dim = feat.embedding_dim
                 else:
-                    out_dim =
+                    out_dim = feat.input_dim
                 dim += out_dim
             elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
                 dim += feat.embedding_dim * feat.max_len
             else:
-                dim += feat.embedding_dim
+                dim += feat.embedding_dim
         return dim
 
     def get_input_dim(self, features: list[object] | None = None) -> int:
+        """Get the input dimension for the network based on embedding layer's output dimension."""
         return self.compute_output_dim(features)  # type: ignore[assignment]
 
     @property
@@ -339,7 +376,8 @@ class ConcatPooling(nn.Module):
     def forward(
         self, x: torch.Tensor, mask: torch.Tensor | None = None
     ) -> torch.Tensor:
-        ... (old line 342 not shown in the diff view)
+        pooled = x.flatten(start_dim=1, end_dim=2)
+        return pooled
 
 
 class AveragePooling(nn.Module):
@@ -349,12 +387,15 @@ class AveragePooling(nn.Module):
     def forward(
         self, x: torch.Tensor, mask: torch.Tensor | None = None
     ) -> torch.Tensor:
+        # mask: matrix with 0/1 values for padding positions
         if mask is None:
-            ... (old line 353 not shown in the diff view)
+            pooled = torch.mean(x, dim=1)
         else:
+            # 0/1 matrix * x
             sum_pooling_matrix = torch.bmm(mask, x).squeeze(1)
             non_padding_length = mask.sum(dim=-1)
-            ... (old line 357 not shown in the diff view)
+            pooled = sum_pooling_matrix / (non_padding_length.float() + 1e-16)
+        return pooled
 
 
 class SumPooling(nn.Module):
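A small worked example of the masked mean that AveragePooling.forward now returns, using the same bmm-based formulation as the hunk above (the tensor values are illustrative only).

import torch

x = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]])  # [B=1, L=3, D=2], last step is padding
mask = torch.tensor([[[1.0, 1.0, 0.0]]])                   # [B=1, 1, L=3] 0/1 validity matrix

summed = torch.bmm(mask, x).squeeze(1)        # [[4.0, 6.0]]
length = mask.sum(dim=-1)                     # [[2.0]] valid steps per example
pooled = summed / (length.float() + 1e-16)    # [[2.0, 3.0]] mean over the valid steps only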
@@ -365,9 +406,184 @@ class SumPooling(nn.Module):
         self, x: torch.Tensor, mask: torch.Tensor | None = None
     ) -> torch.Tensor:
         if mask is None:
-            ... (old line 368 not shown in the diff view)
+            pooled = torch.sum(x, dim=1)
         else:
-            ... (old line 370 not shown in the diff view)
+            pooled = torch.bmm(mask, x).squeeze(1)
+        return pooled
+
+
+class DotProductAttentionPooling(nn.Module):
+    """
+    Dot-product attention pooling with a learnable global query vector.
+
+    Input:
+        x: [B, L, D]
+        mask: [B, 1, L] or [B, L] with 1 for valid tokens, 0 for padding
+    Output:
+        pooled: [B, D]
+    """
+
+    def __init__(self, embedding_dim: int, scale: bool = True, dropout: float = 0.0):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.scale = scale
+        self.dropout = nn.Dropout(dropout)
+        self.query = nn.Parameter(torch.empty(embedding_dim))
+        nn.init.xavier_uniform_(self.query.view(1, -1))
+
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        if x.dim() != 3:
+            raise ValueError(
+                f"[DotProductAttentionPooling Error]: x must be [B,L,D], got {x.shape}"
+            )
+        B, L, D = x.shape
+        if D != self.embedding_dim:
+            raise ValueError(
+                f"[DotProductAttentionPooling Error]: embedding_dim mismatch: {D} vs {self.embedding_dim}"
+            )
+
+        q = self.query.view(1, 1, D)  # [1,1,D]
+        scores = (x * q).sum(dim=-1)  # [B,L]
+        if self.scale:
+            scores = scores / math.sqrt(D)
+
+        if mask is not None:
+            if mask.dim() == 3:  # [B,1,L] or [B,L,1]
+                if mask.size(1) == 1:
+                    mask_ = mask.squeeze(1)  # [B,L]
+                elif mask.size(-1) == 1:
+                    mask_ = mask.squeeze(-1)  # [B,L]
+                else:
+                    raise ValueError(
+                        f"[DotProductAttentionPooling Error]: bad mask shape: {mask.shape}"
+                    )
+            elif mask.dim() == 2:
+                mask_ = mask
+            else:
+                raise ValueError(
+                    f"[DotProductAttentionPooling Error]: bad mask dim: {mask.dim()}"
+                )
+
+            mask_ = mask_.to(dtype=torch.bool)
+            scores = scores.masked_fill(~mask_, float("-inf"))  # mask padding positions
+
+        attn = torch.softmax(scores, dim=-1)  # [B,L]
+        attn = self.dropout(attn)
+        attn = torch.nan_to_num(attn, nan=0.0)  # handle all -inf case
+        pooled = torch.bmm(attn.unsqueeze(1), x).squeeze(1)  # [B,D]
+        return pooled
+
+
+class SelfAttentionPooling(nn.Module):
+    """
+    Self-attention (MHA) to contextualize tokens, then attention pooling to [B,D].
+
+    Input:
+        x: [B, L, D]
+        mask: [B, 1, L] or [B, L] with 1 for valid tokens, 0 for padding
+    Output:
+        pooled: [B, D]
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int = 2,
+        dropout: float = 0.0,
+        use_residual: bool = True,
+        use_layer_norm: bool = True,
+        use_ffn: bool = False,
+    ):
+        super().__init__()
+        if embedding_dim % num_heads != 0:
+            raise ValueError(
+                f"[SelfAttentionPooling Error]: embedding_dim ({embedding_dim}) must be divisible by num_heads ({num_heads})"
+            )
+
+        self.embedding_dim = embedding_dim
+        self.use_residual = use_residual
+        self.use_layer_norm = use_layer_norm
+        self.dropout = nn.Dropout(dropout)
+        self.mha = nn.MultiheadAttention(
+            embed_dim=embedding_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            batch_first=True,
+        )
+        if use_layer_norm:
+            self.layer_norm_1 = nn.LayerNorm(embedding_dim)
+        else:
+            self.layer_norm_1 = None
+
+        self.use_ffn = use_ffn
+        if use_ffn:
+            self.ffn = nn.Sequential(
+                nn.Linear(embedding_dim, 4 * embedding_dim),
+                nn.ReLU(),
+                nn.Dropout(dropout),
+                nn.Linear(4 * embedding_dim, embedding_dim),
+            )
+            if use_layer_norm:
+                self.layer_norm_2 = nn.LayerNorm(embedding_dim)
+            else:
+                self.layer_norm_2 = None
+
+        self.pool = DotProductAttentionPooling(
+            embedding_dim=embedding_dim, scale=True, dropout=dropout
+        )
+
+    def forward(
+        self, x: torch.Tensor, mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        if x.dim() != 3:
+            raise ValueError(
+                f"[SelfAttentionPooling Error]: x must be [B,L,D], got {x.shape}"
+            )
+        B, L, D = x.shape
+        if D != self.embedding_dim:
+            raise ValueError(
+                f"[SelfAttentionPooling Error]: embedding_dim mismatch: {D} vs {self.embedding_dim}"
+            )
+
+        key_padding_mask = None
+        if mask is not None:
+            if mask.dim() == 3:
+                if mask.size(1) == 1:
+                    mask_ = mask.squeeze(1)  # [B,L]
+                elif mask.size(-1) == 1:
+                    mask_ = mask.squeeze(-1)  # [B,L]
+                else:
+                    raise ValueError(
+                        f"[SelfAttentionPooling Error]: bad mask shape: {mask.shape}"
+                    )
+            elif mask.dim() == 2:
+                mask_ = mask
+            else:
+                raise ValueError(
+                    f"[SelfAttentionPooling Error]: bad mask dim: {mask.dim()}"
+                )
+            key_padding_mask = ~mask_.to(dtype=torch.bool)  # True = padding
+
+        attn_out, _ = self.mha(
+            x, x, x, key_padding_mask=key_padding_mask, need_weights=False
+        )
+        if self.use_residual:
+            x = x + self.dropout(attn_out)
+        else:
+            x = self.dropout(attn_out)
+        if self.layer_norm_1 is not None:
+            x = self.layer_norm_1(x)
+
+        if self.use_ffn:
+            ffn_out = self.ffn(x)
+            x = x + self.dropout(ffn_out)
+            if self.layer_norm_2 is not None:
+                x = self.layer_norm_2(x)
+
+        pooled = self.pool(x, mask=mask)
+        return pooled
 
 
 class MLP(nn.Module):
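A minimal usage sketch for the two pooling layers introduced above; shapes follow their docstrings and the batch/length/dimension values here are arbitrary.

import torch

from nextrec.basic.layers import DotProductAttentionPooling, SelfAttentionPooling

B, L, D = 4, 10, 16
x = torch.randn(B, L, D)
mask = torch.ones(B, L)
mask[:, 7:] = 0  # mark the last three positions as padding

dot_pool = DotProductAttentionPooling(embedding_dim=D)
pooled = dot_pool(x, mask)   # [B, D]

self_pool = SelfAttentionPooling(embedding_dim=D, num_heads=4, dropout=0.0)
pooled = self_pool(x, mask)  # [B, D]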
@@ -377,10 +593,45 @@ class MLP(nn.Module):
         output_layer: bool = True,
         dims: list[int] | None = None,
         dropout: float = 0.0,
-        activation:
+        activation: Literal[
+            "dice",
+            "relu",
+            "relu6",
+            "elu",
+            "selu",
+            "leaky_relu",
+            "prelu",
+            "gelu",
+            "sigmoid",
+            "tanh",
+            "softplus",
+            "softsign",
+            "hardswish",
+            "mish",
+            "silu",
+            "swish",
+            "hardsigmoid",
+            "tanhshrink",
+            "softshrink",
+            "none",
+            "linear",
+            "identity",
+        ] = "relu",
         use_norm: bool = True,
-        norm_type:
+        norm_type: Literal["batch_norm", "layer_norm"] = "layer_norm",
     ):
+        """
+        Multi-Layer Perceptron (MLP) module.
+
+        Args:
+            input_dim: Dimension of the input features.
+            output_layer: Whether to include the final output layer. If False, the MLP will output the last hidden layer, else it will output a single value.
+            dims: List of hidden layer dimensions. If None, no hidden layers are added.
+            dropout: Dropout rate between layers.
+            activation: Activation function to use between layers.
+            use_norm: Whether to use normalization layers.
+            norm_type: Type of normalization to use ("batch_norm" or "layer_norm").
+        """
         super().__init__()
         if dims is None:
             dims = []
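A usage sketch for the newly documented MLP signature. Passing input_dim as the first parameter follows the added docstring; it is not shown directly in this hunk, so treat it as an assumption.

import torch

from nextrec.basic.layers import MLP

mlp = MLP(
    input_dim=32,
    output_layer=False,        # keep the last hidden layer instead of a single output value
    dims=[64, 32],
    dropout=0.1,
    activation="gelu",
    use_norm=True,
    norm_type="layer_norm",
)
hidden = mlp(torch.randn(8, 32))  # [8, 32] when output_layer=False, per the docstring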
@@ -457,7 +708,12 @@ class SENETLayer(nn.Module):
 
 class BiLinearInteractionLayer(nn.Module):
     def __init__(
-        self,
+        self,
+        input_dim: int,
+        num_fields: int,
+        bilinear_type: Literal[
+            "field_all", "field_each", "field_interaction"
+        ] = "field_interaction",
     ):
         super(BiLinearInteractionLayer, self).__init__()
         self.bilinear_type = bilinear_type
@@ -531,14 +787,16 @@ class MultiHeadSelfAttention(nn.Module):
         self.use_residual = use_residual
         self.dropout_rate = dropout
 
-        self.
+        self.q_proj = nn.Linear(
             embedding_dim, embedding_dim, bias=False
         )  # Query projection
-        self.
-        ... (old line 538 not shown in the diff view)
+        self.k_proj = nn.Linear(
+            embedding_dim, embedding_dim, bias=False
+        )  # Key projection
+        self.v_proj = nn.Linear(
             embedding_dim, embedding_dim, bias=False
         )  # Value projection
-        self.
+        self.out_proj = nn.Linear(
             embedding_dim, embedding_dim, bias=False
         )  # Output projection
 
@@ -557,15 +815,15 @@ class MultiHeadSelfAttention(nn.Module):
         # x: [Batch, Length, Dim]
         B, L, D = x.shape
 
-        ... (old lines 560-562 not shown in the diff view)
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
 
-        ... (old line 564 not shown in the diff view)
+        q = q.view(B, L, self.num_heads, self.head_dim).transpose(
             1, 2
         )  # [Batch, Heads, Length, head_dim]
-        ... (old lines 567-568 not shown in the diff view)
+        k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
+        v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
 
         key_padding_mask = None
         if attention_mask is not None:
@@ -582,22 +840,22 @@ class MultiHeadSelfAttention(nn.Module):
 
         if self.use_flash_attention:
             attn = F.scaled_dot_product_attention(
-                ... (old lines 585-587 not shown in the diff view)
+                q,
+                k,
+                v,
                 attn_mask=attn_mask,
                 dropout_p=self.dropout_rate if self.training else 0.0,
             )  # [B,H,L,dh]
         else:
-            scores = torch.matmul(
+            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
             if attn_mask is not None:
                 scores = scores.masked_fill(attn_mask, float("-inf"))
             attn_weights = torch.softmax(scores, dim=-1)
             attn_weights = self.dropout(attn_weights)
-            attn = torch.matmul(attn_weights,
+            attn = torch.matmul(attn_weights, v)  # [B,H,L,dh]
 
         attn = attn.transpose(1, 2).contiguous().view(B, L, D)
-        out = self.
+        out = self.out_proj(attn)
 
         if self.use_residual:
             out = out + x
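The manual branch above is plain scaled dot-product attention; with dropout off and no mask it matches F.scaled_dot_product_attention, as this self-contained check illustrates.

import torch
import torch.nn.functional as F

B, H, L, dh = 2, 4, 5, 8
q, k, v = (torch.randn(B, H, L, dh) for _ in range(3))

scores = torch.matmul(q, k.transpose(-2, -1)) / (dh**0.5)  # same scaling as self.head_dim**0.5
manual = torch.matmul(torch.softmax(scores, dim=-1), v)    # [B, H, L, dh]

flash = F.scaled_dot_product_attention(q, k, v)            # PyTorch's fused equivalent
assert torch.allclose(manual, flash, atol=1e-5)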
@@ -620,7 +878,30 @@ class AttentionPoolingLayer(nn.Module):
         self,
         embedding_dim: int,
         hidden_units: list = [80, 40],
-        activation:
+        activation: Literal[
+            "dice",
+            "relu",
+            "relu6",
+            "elu",
+            "selu",
+            "leaky_relu",
+            "prelu",
+            "gelu",
+            "sigmoid",
+            "tanh",
+            "softplus",
+            "softsign",
+            "hardswish",
+            "mish",
+            "silu",
+            "swish",
+            "hardsigmoid",
+            "tanhshrink",
+            "softshrink",
+            "none",
+            "linear",
+            "identity",
+        ] = "sigmoid",
         use_softmax: bool = False,
     ):
         super().__init__()