nextrec 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +10 -23
- nextrec/basic/layers.py +18 -61
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/metrics.py +55 -33
- nextrec/basic/model.py +258 -394
- nextrec/data/__init__.py +2 -2
- nextrec/data/data_utils.py +80 -4
- nextrec/data/dataloader.py +36 -57
- nextrec/data/preprocessor.py +5 -4
- nextrec/models/generative/__init__.py +5 -0
- nextrec/models/generative/hstu.py +399 -0
- nextrec/models/match/dssm.py +2 -2
- nextrec/models/match/dssm_v2.py +2 -2
- nextrec/models/match/mind.py +2 -2
- nextrec/models/match/sdm.py +2 -2
- nextrec/models/match/youtube_dnn.py +2 -2
- nextrec/models/multi_task/esmm.py +1 -1
- nextrec/models/multi_task/mmoe.py +1 -1
- nextrec/models/multi_task/ple.py +1 -1
- nextrec/models/multi_task/poso.py +1 -1
- nextrec/models/multi_task/share_bottom.py +1 -1
- nextrec/models/ranking/afm.py +1 -1
- nextrec/models/ranking/autoint.py +1 -1
- nextrec/models/ranking/dcn.py +1 -1
- nextrec/models/ranking/deepfm.py +1 -1
- nextrec/models/ranking/dien.py +1 -1
- nextrec/models/ranking/din.py +1 -1
- nextrec/models/ranking/fibinet.py +1 -1
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +1 -1
- nextrec/models/ranking/widedeep.py +1 -1
- nextrec/models/ranking/xdeepfm.py +1 -1
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/common.py +21 -2
- nextrec/utils/optimizer.py +7 -3
- {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/METADATA +10 -4
- nextrec-0.3.3.dist-info/RECORD +57 -0
- nextrec-0.3.1.dist-info/RECORD +0 -56
- {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/WHEEL +0 -0
- {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
|
|
2
|
+
[Info: this version is not released yet; more research on the source code and paper is needed]
|
|
3
|
+
Date: create on 01/12/2025
|
|
4
|
+
Checkpoint: edit on 01/12/2025
|
|
5
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
+
Reference:
|
|
7
|
+
[1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
|
|
8
|
+
[2] Ma W, Li P, Chen C, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
|
|
9
|
+
|
|
10
|
+
Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
|
|
11
|
+
Meta’s Generative Recommenders. It replaces softmax attention with lightweight
|
|
12
|
+
pointwise activations, enabling extremely deep stacks on long behavior sequences.
|
|
13
|
+
|
|
14
|
+
In each HSTU layer:
|
|
15
|
+
(1) Tokens are projected into four streams U, V, Q, K via a shared feed-forward block
|
|
16
|
+
(2) Softmax-free interactions combine QK^T with Relative Attention Bias (RAB) to encode distance
|
|
17
|
+
(3) Aggregated context is modulated by U-gating and mapped back through an output projection
|
|
18
|
+
|
|
19
|
+
Stacking layers yields an efficient causal encoder for next-item
|
|
20
|
+
generation. With a tied-embedding LM head, HSTU forms
|
|
21
|
+
a full generative recommendation model.
|
|
22
|
+
|
|
23
|
+
Key Advantages:
|
|
24
|
+
- Softmax-free attention scales better on deep/long sequences
|
|
25
|
+
- RAB captures temporal structure without extra attention heads
|
|
26
|
+
- Causal masking and padding-aware normalization fit real logs
|
|
27
|
+
- Weight tying reduces parameters and stabilizes training
|
|
28
|
+
- Serves as a drop-in backbone for generative recommendation
|
|
29
|
+
|
|
30
|
+
HSTU(层次化序列转导单元)是 Meta 生成式推荐的核心编码器,
|
|
31
|
+
用点式激活替代 softmax 注意力,可在长序列上轻松堆叠深层结构。
|
|
32
|
+
|
|
33
|
+
单层 HSTU 的主要步骤:
|
|
34
|
+
(1) 将输入一次性映射到 U、V、Q、K 四条通路
|
|
35
|
+
(2) 利用不含 softmax 的 QK^T 结合相对位置偏置(RAB)建模距离信息
|
|
36
|
+
(3) 用 U 对聚合上下文进行门控,再映射回输出空间
|
|
37
|
+
|
|
38
|
+
多层堆叠后,可得到高效的因果编码器;与绑权 LM 头配合即可完成 next-item 预测。
|
|
39
|
+
|
|
40
|
+
主要优势:
|
|
41
|
+
- 摆脱 softmax,在长序列、深层模型上更易扩展
|
|
42
|
+
- 相对位置偏置稳健刻画时序结构
|
|
43
|
+
- 因果 mask 与 padding 感知归一化贴合真实日志
|
|
44
|
+
- 绑权输出头降低参数量并提升训练稳定性
|
|
45
|
+
- 直接作为生成式推荐的骨干网络
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
from __future__ import annotations
|
|
49
|
+
|
|
50
|
+
import math
|
|
51
|
+
from typing import Optional
|
|
52
|
+
|
|
53
|
+
import torch
|
|
54
|
+
import torch.nn as nn
|
|
55
|
+
import torch.nn.functional as F
|
|
56
|
+
|
|
57
|
+
from nextrec.basic.model import BaseModel
|
|
58
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _relative_position_bucket(
|
|
62
|
+
relative_position: torch.Tensor,
|
|
63
|
+
num_buckets: int = 32,
|
|
64
|
+
max_distance: int = 128,
|
|
65
|
+
) -> torch.Tensor:
|
|
66
|
+
"""
|
|
67
|
+
map the relative position (i-j) to a bucket in [0, num_buckets).
|
|
68
|
+
"""
|
|
69
|
+
# only need the negative part for causal attention
|
|
70
|
+
n = -relative_position
|
|
71
|
+
n = torch.clamp(n, min=0)
|
|
72
|
+
|
|
73
|
+
# when the distance is small, keep it exact
|
|
74
|
+
max_exact = num_buckets // 2
|
|
75
|
+
is_small = n < max_exact
|
|
76
|
+
|
|
77
|
+
# when the distance is too far, do log scaling
|
|
78
|
+
large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
|
|
79
|
+
large_val = torch.clamp(large_val, max=num_buckets - 1)
|
|
80
|
+
|
|
81
|
+
buckets = torch.where(is_small, n.long(), large_val)
|
|
82
|
+
return buckets
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RelativePositionBias(nn.Module):
    """Learned relative-position bias (RAB) for HSTU attention.

    Given a sequence length T, produces an additive bias tensor of shape
    [1, num_heads, T, T], indexed by the bucketized relative distance
    between each (query, key) pair.
    """

    def __init__(
        self,
        num_heads: int,
        num_buckets: int = 32,
        max_distance: int = 128,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # One learned scalar per (bucket, head).
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        positions = torch.arange(seq_len, device=device)
        # rel[i, j] = j - i for every (query i, key j) pair: [T, T]
        rel = positions[None, :] - positions[:, None]
        bucket_ids = _relative_position_bucket(rel, num_buckets=self.num_buckets, max_distance=self.max_distance)
        bias = self.embedding(bucket_ids)  # [T, T, num_heads]
        # Rearranged to [1, num_heads, T, T] so it broadcasts over the batch.
        return bias.permute(2, 0, 1).unsqueeze(0)
|
|
111
|
+
|
|
112
|
+
class HSTUPointwiseAttention(nn.Module):
    """
    Pointwise aggregation attention that implements HSTU without softmax:
      1) [U, V, Q, K] = split( φ1(f1(X)) ), U: gate, V: value, Q: query, K: key
      2) AV = ( φ2(QK^T + rab) ⊙ M ) V / N, where M zeroes disallowed pairs and
         N is the per-query count of allowed keys
      3) Y = f2( Norm(AV) ⊙ U )
    φ1, φ2 use SiLU; Norm uses LayerNorm.

    Masking note (bug fix): disallowed (future / padded) pairs are removed by
    zeroing the SiLU-activated scores instead of adding -inf to the logits.
    SiLU(-inf) evaluates to NaN (-inf * sigmoid(-inf) = -inf * 0), and a single
    NaN score would propagate through the value matmul and poison every output
    position.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        alpha: Optional[float] = None,
    ):
        """
        Args:
            d_model: model width; must be divisible by num_heads.
            num_heads: number of attention heads.
            dropout: dropout rate applied to the gated context before out_proj.
            alpha: scale for the attention logits; defaults to 1/sqrt(d_head).
        """
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
        # Single projection producing all four streams U, V, Q, K at once.
        self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
        # Maps the gated context back to d_model.
        self.out_proj = nn.Linear(d_model, d_model, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        [B, T, D] -> [B, H, T, d_head]
        """
        B, T, D = x.shape
        return x.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 (attend) or -inf (masked)
        key_padding_mask: Optional[torch.Tensor] = None,  # [B, T], True = pad
        rab: Optional[torch.Tensor] = None,  # [1, H, T, T] or [B, H, T, T] or None
    ) -> torch.Tensor:
        B, T, D = x.shape

        # Eq.(1): [U, V, Q, K] = split( φ1(f1(X)) )
        h = F.silu(self.in_proj(x))  # [B, T, 4D]
        U, V, Q, K = h.chunk(4, dim=-1)  # each [B, T, D]

        Qh = self._reshape_heads(Q)  # [B, H, T, d_head]
        Kh = self._reshape_heads(K)  # [B, H, T, d_head]
        Vh = self._reshape_heads(V)  # [B, H, T, d_head]
        Uh = self._reshape_heads(U)  # [B, H, T, d_head]

        # Attention logits: scaled QK^T, no softmax. Kept finite so that the
        # SiLU below stays NaN-free (see class docstring).
        logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha  # [B, H, T, T]

        # Add relative position bias (rab^p); extensible to rab^t later.
        if rab is not None:
            logits = logits + rab

        # "allowed" marks valid (query i, key j) pairs: 1 = attend, 0 = masked.
        # It serves both as the multiplicative mask M and as the basis for N.
        allowed = torch.ones_like(logits, dtype=torch.float)  # [B, H, T, T]

        # Causal mask: attn_mask is [T, T] with zeros on allowed entries and
        # -inf above the diagonal; we only read where it equals zero.
        if attn_mask is not None:
            allowed = allowed * (attn_mask.view(1, 1, T, T) == 0).float()

        # Padding mask: key_padding_mask is [B, T], True marks padded keys.
        if key_padding_mask is not None:
            valid = (~key_padding_mask).float().view(B, 1, 1, T)  # 1 = non-pad
            allowed = allowed * valid

        # Eq.(2): A(X)V(X) = ( φ2(QK^T + rab) ⊙ M ) V(X) / N
        attn = F.silu(logits) * allowed  # zero out masked pairs AFTER the activation
        denom = allowed.sum(dim=-1, keepdim=True).clamp(min=1.0)  # N >= 1 avoids 0/0
        attn = attn / denom  # [B, H, T, T]

        AV = torch.matmul(attn, Vh)  # [B, H, T, d_head]
        AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # back to [B, T, D]
        U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)

        # Eq.(3): Y = f2( Norm(AV) ⊙ U )
        y = self.out_proj(self.dropout(self.norm(AV) * U_flat))  # [B, T, D]
        return y
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class HSTULayer(nn.Module):
    """A single HSTU block: pointwise attention wrapped in a residual path,
    with an optional learned relative-position bias (RAB)."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        use_rab_pos: bool = True,
        rab_num_buckets: int = 32,
        rab_max_distance: int = 128,
    ):
        super().__init__()
        self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.use_rab_pos = use_rab_pos
        if use_rab_pos:
            self.rel_pos_bias = RelativePositionBias(
                num_heads=num_heads,
                num_buckets=rab_num_buckets,
                max_distance=rab_max_distance,
            )
        else:
            self.rel_pos_bias = None

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention sub-layer and add the residual.

        x: [B, T, D] -> [B, T, D]
        """
        seq_len = x.size(1)
        rab = self.rel_pos_bias(seq_len=seq_len, device=x.device) if self.use_rab_pos else None
        attn_out = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab)
        return x + self.dropout(attn_out)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class HSTU(BaseModel):
    """
    HSTU encoder for next-item prediction in a causal, generative setup.

    Pipeline:
    1) Embed tokens + positions from the behavior history
    2) Apply stacked HSTU layers with causal mask and optional RAB
    3) Use the last valid position to produce next-item logits via tied LM head

    Only the FIRST entry of ``sequence_features`` is consumed as the behavior
    history in this demo version; other sequence features are passed to
    BaseModel but not used by ``forward``.
    """

    @property
    def model_name(self) -> str:
        return "HSTU"

    @property
    def task_type(self) -> str:
        # Next-item prediction over the item vocabulary.
        return "multiclass"

    def __init__(
        self,
        sequence_features: list[SequenceFeature],
        dense_features: Optional[list[DenseFeature]] = None,
        sparse_features: Optional[list[SparseFeature]] = None,
        d_model: Optional[int] = None,
        num_heads: int = 8,
        num_layers: int = 4,
        max_seq_len: int = 200,
        dropout: float = 0.1,
        # RAB settings
        use_rab_pos: bool = True,
        rab_num_buckets: int = 32,
        rab_max_distance: int = 128,

        tie_embeddings: bool = True,
        target: Optional[list[str] | str] = None,
        optimizer: str = "adam",
        optimizer_params: Optional[dict] = None,
        scheduler: Optional[str] = None,
        scheduler_params: Optional[dict] = None,
        loss_params: Optional[dict] = None,
        embedding_l1_reg: float = 0.0,
        dense_l1_reg: float = 0.0,
        embedding_l2_reg: float = 0.0,
        dense_l2_reg: float = 0.0,
        device: str = "cpu",
        **kwargs,
    ):
        if not sequence_features:
            raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")

        # demo version: use the first SequenceFeature as the main sequence
        # NOTE(review): plain attributes are assigned before super().__init__();
        # this assumes SequenceFeature is not an nn.Module (nn.Module.__setattr__
        # would reject a sub-module before Module.__init__ runs) — TODO confirm.
        self.history_feature = sequence_features[0]

        # Fall back to the feature's embedding_dim (min 32) when d_model is not given.
        hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
        # Make hidden_dim divisible by num_heads
        if hidden_dim % num_heads != 0:
            hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)

        # padding_idx is coalesced to 0 here, so it is never None afterwards.
        self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
        self.vocab_size = self.history_feature.vocab_size
        self.max_seq_len = max_seq_len

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=self.task_type,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            **kwargs,
        )

        # token & position embedding (paper usually includes pos embedding / RAB in encoder)
        self.token_embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=hidden_dim,
            padding_idx=self.padding_idx,
        )
        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
        self.input_dropout = nn.Dropout(dropout)

        # HSTU layers
        self.layers = nn.ModuleList([HSTULayer(d_model=hidden_dim, num_heads=num_heads, dropout=dropout, use_rab_pos=use_rab_pos,
                                               rab_num_buckets=rab_num_buckets, rab_max_distance=rab_max_distance) for _ in range(num_layers)])

        self.final_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
        if tie_embeddings:
            # Weight tying: LM head shares the token embedding matrix.
            self.lm_head.weight = self.token_embedding.weight

        # causal mask buffer (non-persistent; grown lazily in _build_causal_mask)
        self.register_buffer("causal_mask", torch.empty(0), persistent=False)
        # NOTE(review): padding_idx can no longer be None at this point (coalesced
        # to 0 above), so the -100 fallback is dead code.
        self.ignore_index = self.padding_idx if self.padding_idx is not None else -100

        optimizer_params = optimizer_params or {}
        scheduler_params = scheduler_params or {}
        loss_params = loss_params or {}
        # Padding tokens never contribute to the cross-entropy loss.
        loss_params.setdefault("ignore_index", self.ignore_index)

        # NOTE(review): compile/register_regularization_weights semantics come from
        # BaseModel, which is not visible here — assumed to set up loss_fn/optimizer.
        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params, loss="crossentropy", loss_params=loss_params)
        self.register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])

    def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        build causal mask of shape [T, T]: upper triangle is -inf, others are 0.
        This will be added to the logits to simulate causal structure.
        The mask is cached and only rebuilt when a longer sequence arrives.
        """
        if self.causal_mask.numel() == 0 or self.causal_mask.size(0) < seq_len:
            mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
            mask = torch.triu(mask, diagonal=1)
            self.causal_mask = mask
        return self.causal_mask[:seq_len, :seq_len]

    def _trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
        # Keep only the most recent max_seq_len steps of the history.
        if seq.size(1) <= self.max_seq_len:
            return seq
        return seq[:, -self.max_seq_len :]

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """Encode the behavior history and return next-item logits [B, vocab_size]."""
        seq = x[self.history_feature.name].long()  # [B, T_raw]
        seq = self._trim_sequence(seq)  # [B, T]

        B, T = seq.shape
        device = seq.device
        # position ids: [B, T]
        pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
        token_emb = self.token_embedding(seq)  # [B, T, D]
        pos_emb = self.position_embedding(pos_ids)  # [B, T, D]
        hidden_states = self.input_dropout(token_emb + pos_emb)

        # padding mask: True = pad
        padding_mask = seq.eq(self.padding_idx)  # [B, T]
        attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]

        for layer in self.layers:
            hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
        hidden_states = self.final_norm(hidden_states)  # [B, T, D]

        # Pick the hidden state of the last non-padded position per sample.
        # assumes sequences are LEFT-aligned (padding at the end) — TODO confirm
        # against the dataloader; right-padded input is required for this index math.
        valid_lengths = (~padding_mask).sum(dim=1)  # [B]
        last_index = (valid_lengths - 1).clamp(min=0)  # clamp guards all-pad rows
        last_hidden = hidden_states[torch.arange(B, device=device), last_index]  # [B, D]

        logits = self.lm_head(last_hidden)  # [B, vocab_size]
        return logits

    def compute_loss(self, y_pred, y_true):
        """
        y_true: [B] or [B, 1], the id of the next item.
        Cross-entropy over the vocabulary; padding ids are ignored via ignore_index.
        """
        if y_true is None:
            raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
        labels = y_true.view(-1).long()
        # NOTE(review): assumes BaseModel.compile stores losses as a list in
        # self.loss_fn — confirm against nextrec.basic.model.
        return self.loss_fn[0](y_pred, labels)
|
nextrec/models/match/dssm.py
CHANGED
|
@@ -143,11 +143,11 @@ class DSSM(BaseMatchModel):
|
|
|
143
143
|
activation=dnn_activation
|
|
144
144
|
)
|
|
145
145
|
|
|
146
|
-
self.
|
|
146
|
+
self.register_regularization_weights(
|
|
147
147
|
embedding_attr='user_embedding',
|
|
148
148
|
include_modules=['user_dnn']
|
|
149
149
|
)
|
|
150
|
-
self.
|
|
150
|
+
self.register_regularization_weights(
|
|
151
151
|
embedding_attr='item_embedding',
|
|
152
152
|
include_modules=['item_dnn']
|
|
153
153
|
)
|
nextrec/models/match/dssm_v2.py
CHANGED
|
@@ -134,11 +134,11 @@ class DSSM_v2(BaseMatchModel):
|
|
|
134
134
|
activation=dnn_activation
|
|
135
135
|
)
|
|
136
136
|
|
|
137
|
-
self.
|
|
137
|
+
self.register_regularization_weights(
|
|
138
138
|
embedding_attr='user_embedding',
|
|
139
139
|
include_modules=['user_dnn']
|
|
140
140
|
)
|
|
141
|
-
self.
|
|
141
|
+
self.register_regularization_weights(
|
|
142
142
|
embedding_attr='item_embedding',
|
|
143
143
|
include_modules=['item_dnn']
|
|
144
144
|
)
|
nextrec/models/match/mind.py
CHANGED
|
@@ -258,11 +258,11 @@ class MIND(BaseMatchModel):
|
|
|
258
258
|
else:
|
|
259
259
|
self.item_dnn = None
|
|
260
260
|
|
|
261
|
-
self.
|
|
261
|
+
self.register_regularization_weights(
|
|
262
262
|
embedding_attr='user_embedding',
|
|
263
263
|
include_modules=['capsule_network']
|
|
264
264
|
)
|
|
265
|
-
self.
|
|
265
|
+
self.register_regularization_weights(
|
|
266
266
|
embedding_attr='item_embedding',
|
|
267
267
|
include_modules=['item_dnn'] if self.item_dnn else []
|
|
268
268
|
)
|
nextrec/models/match/sdm.py
CHANGED
|
@@ -176,11 +176,11 @@ class SDM(BaseMatchModel):
|
|
|
176
176
|
else:
|
|
177
177
|
self.item_dnn = None
|
|
178
178
|
|
|
179
|
-
self.
|
|
179
|
+
self.register_regularization_weights(
|
|
180
180
|
embedding_attr='user_embedding',
|
|
181
181
|
include_modules=['rnn', 'user_dnn']
|
|
182
182
|
)
|
|
183
|
-
self.
|
|
183
|
+
self.register_regularization_weights(
|
|
184
184
|
embedding_attr='item_embedding',
|
|
185
185
|
include_modules=['item_dnn'] if self.item_dnn else []
|
|
186
186
|
)
|
|
@@ -140,11 +140,11 @@ class YoutubeDNN(BaseMatchModel):
|
|
|
140
140
|
activation=dnn_activation
|
|
141
141
|
)
|
|
142
142
|
|
|
143
|
-
self.
|
|
143
|
+
self.register_regularization_weights(
|
|
144
144
|
embedding_attr='user_embedding',
|
|
145
145
|
include_modules=['user_dnn']
|
|
146
146
|
)
|
|
147
|
-
self.
|
|
147
|
+
self.register_regularization_weights(
|
|
148
148
|
embedding_attr='item_embedding',
|
|
149
149
|
include_modules=['item_dnn']
|
|
150
150
|
)
|
|
@@ -128,7 +128,7 @@ class ESMM(BaseModel):
|
|
|
128
128
|
self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
|
|
129
129
|
self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
|
|
130
130
|
# Register regularization weights
|
|
131
|
-
self.
|
|
131
|
+
self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
|
|
132
132
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
133
133
|
|
|
134
134
|
def forward(self, x):
|
|
@@ -146,7 +146,7 @@ class MMOE(BaseModel):
|
|
|
146
146
|
self.towers.append(tower)
|
|
147
147
|
self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
|
|
148
148
|
# Register regularization weights
|
|
149
|
-
self.
|
|
149
|
+
self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
|
|
150
150
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
|
|
151
151
|
|
|
152
152
|
def forward(self, x):
|
nextrec/models/multi_task/ple.py
CHANGED
|
@@ -249,7 +249,7 @@ class PLE(BaseModel):
|
|
|
249
249
|
self.towers.append(tower)
|
|
250
250
|
self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
|
|
251
251
|
# Register regularization weights
|
|
252
|
-
self.
|
|
252
|
+
self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
|
|
253
253
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
|
|
254
254
|
|
|
255
255
|
def forward(self, x):
|
|
@@ -389,7 +389,7 @@ class POSO(BaseModel):
|
|
|
389
389
|
self.tower_heads = None
|
|
390
390
|
self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks,)
|
|
391
391
|
include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
|
|
392
|
-
self.
|
|
392
|
+
self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
|
|
393
393
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
394
394
|
|
|
395
395
|
def forward(self, x):
|
|
@@ -122,7 +122,7 @@ class ShareBottom(BaseModel):
|
|
|
122
122
|
self.towers.append(tower)
|
|
123
123
|
self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
|
|
124
124
|
# Register regularization weights
|
|
125
|
-
self.
|
|
125
|
+
self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
|
|
126
126
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
127
127
|
|
|
128
128
|
def forward(self, x):
|
nextrec/models/ranking/afm.py
CHANGED
|
@@ -81,7 +81,7 @@ class AFM(BaseModel):
|
|
|
81
81
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
82
82
|
|
|
83
83
|
# Register regularization weights
|
|
84
|
-
self.
|
|
84
|
+
self.register_regularization_weights(
|
|
85
85
|
embedding_attr='embedding',
|
|
86
86
|
include_modules=['linear', 'attention_linear', 'attention_p', 'output_projection']
|
|
87
87
|
)
|
|
@@ -150,7 +150,7 @@ class AutoInt(BaseModel):
|
|
|
150
150
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
151
151
|
|
|
152
152
|
# Register regularization weights
|
|
153
|
-
self.
|
|
153
|
+
self.register_regularization_weights(
|
|
154
154
|
embedding_attr='embedding',
|
|
155
155
|
include_modules=['projection_layers', 'attention_layers', 'fc']
|
|
156
156
|
)
|
nextrec/models/ranking/dcn.py
CHANGED
|
@@ -109,7 +109,7 @@ class DCN(BaseModel):
|
|
|
109
109
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
110
110
|
|
|
111
111
|
# Register regularization weights
|
|
112
|
-
self.
|
|
112
|
+
self.register_regularization_weights(
|
|
113
113
|
embedding_attr='embedding',
|
|
114
114
|
include_modules=['cross_network', 'mlp', 'final_layer']
|
|
115
115
|
)
|
nextrec/models/ranking/deepfm.py
CHANGED
|
@@ -107,7 +107,7 @@ class DeepFM(BaseModel):
|
|
|
107
107
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
108
108
|
|
|
109
109
|
# Register regularization weights
|
|
110
|
-
self.
|
|
110
|
+
self.register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
|
|
111
111
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
112
112
|
|
|
113
113
|
def forward(self, x):
|
nextrec/models/ranking/dien.py
CHANGED
|
@@ -237,7 +237,7 @@ class DIEN(BaseModel):
|
|
|
237
237
|
self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
|
|
238
238
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
239
239
|
# Register regularization weights
|
|
240
|
-
self.
|
|
240
|
+
self.register_regularization_weights(embedding_attr='embedding', include_modules=['interest_extractor', 'interest_evolution', 'attention_layer', 'mlp', 'candidate_proj'])
|
|
241
241
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
242
242
|
|
|
243
243
|
def forward(self, x):
|
nextrec/models/ranking/din.py
CHANGED
|
@@ -108,7 +108,7 @@ class DIN(BaseModel):
|
|
|
108
108
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
109
109
|
|
|
110
110
|
# Register regularization weights
|
|
111
|
-
self.
|
|
111
|
+
self.register_regularization_weights(
|
|
112
112
|
embedding_attr='embedding',
|
|
113
113
|
include_modules=['attention', 'mlp', 'candidate_attention_proj']
|
|
114
114
|
)
|
|
@@ -104,7 +104,7 @@ class FiBiNET(BaseModel):
|
|
|
104
104
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
105
105
|
|
|
106
106
|
# Register regularization weights
|
|
107
|
-
self.
|
|
107
|
+
self.register_regularization_weights(
|
|
108
108
|
embedding_attr='embedding',
|
|
109
109
|
include_modules=['linear', 'senet', 'bilinear_standard', 'bilinear_senet', 'mlp']
|
|
110
110
|
)
|
nextrec/models/ranking/fm.py
CHANGED
|
@@ -69,7 +69,7 @@ class FM(BaseModel):
|
|
|
69
69
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
70
70
|
|
|
71
71
|
# Register regularization weights
|
|
72
|
-
self.
|
|
72
|
+
self.register_regularization_weights(
|
|
73
73
|
embedding_attr='embedding',
|
|
74
74
|
include_modules=['linear']
|
|
75
75
|
)
|
|
@@ -234,10 +234,10 @@ class MaskNet(BaseModel):
|
|
|
234
234
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
235
235
|
|
|
236
236
|
if self.model_type == "serial":
|
|
237
|
-
self.
|
|
237
|
+
self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
|
|
238
238
|
# serial
|
|
239
239
|
else:
|
|
240
|
-
self.
|
|
240
|
+
self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"])
|
|
241
241
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
242
242
|
|
|
243
243
|
def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
|
nextrec/models/ranking/pnn.py
CHANGED
|
@@ -111,7 +111,7 @@ class WideDeep(BaseModel):
|
|
|
111
111
|
self.mlp = MLP(input_dim=input_dim, **mlp_params)
|
|
112
112
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
113
113
|
# Register regularization weights
|
|
114
|
-
self.
|
|
114
|
+
self.register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
|
|
115
115
|
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
116
116
|
|
|
117
117
|
def forward(self, x):
|
|
@@ -121,7 +121,7 @@ class xDeepFM(BaseModel):
|
|
|
121
121
|
self.prediction_layer = PredictionLayer(task_type=self.task_type)
|
|
122
122
|
|
|
123
123
|
# Register regularization weights
|
|
124
|
-
self.
|
|
124
|
+
self.register_regularization_weights(
|
|
125
125
|
embedding_attr='embedding',
|
|
126
126
|
include_modules=['linear', 'cin', 'mlp']
|
|
127
127
|
)
|