nextrec 0.2.7__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +4 -8
- nextrec/basic/callback.py +1 -1
- nextrec/basic/features.py +33 -25
- nextrec/basic/layers.py +164 -601
- nextrec/basic/loggers.py +4 -5
- nextrec/basic/metrics.py +39 -115
- nextrec/basic/model.py +257 -177
- nextrec/basic/session.py +1 -5
- nextrec/data/__init__.py +12 -0
- nextrec/data/data_utils.py +3 -27
- nextrec/data/dataloader.py +26 -34
- nextrec/data/preprocessor.py +2 -1
- nextrec/loss/listwise.py +6 -4
- nextrec/loss/loss_utils.py +10 -6
- nextrec/loss/pairwise.py +5 -3
- nextrec/loss/pointwise.py +7 -13
- nextrec/models/generative/__init__.py +5 -0
- nextrec/models/generative/hstu.py +399 -0
- nextrec/models/match/mind.py +110 -1
- nextrec/models/multi_task/esmm.py +46 -27
- nextrec/models/multi_task/mmoe.py +48 -30
- nextrec/models/multi_task/ple.py +156 -141
- nextrec/models/multi_task/poso.py +413 -0
- nextrec/models/multi_task/share_bottom.py +43 -26
- nextrec/models/ranking/__init__.py +2 -0
- nextrec/models/ranking/dcn.py +20 -1
- nextrec/models/ranking/dcn_v2.py +84 -0
- nextrec/models/ranking/deepfm.py +44 -18
- nextrec/models/ranking/dien.py +130 -27
- nextrec/models/ranking/masknet.py +13 -67
- nextrec/models/ranking/widedeep.py +39 -18
- nextrec/models/ranking/xdeepfm.py +34 -1
- nextrec/utils/common.py +26 -1
- nextrec/utils/optimizer.py +7 -3
- nextrec-0.3.2.dist-info/METADATA +312 -0
- nextrec-0.3.2.dist-info/RECORD +57 -0
- nextrec-0.2.7.dist-info/METADATA +0 -281
- nextrec-0.2.7.dist-info/RECORD +0 -54
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/WHEEL +0 -0
- {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
|
|
2
|
+
[Info: this version is not released yet; more research on the source code and paper is needed.]
|
|
3
|
+
Date: create on 01/12/2025
|
|
4
|
+
Checkpoint: edit on 01/12/2025
|
|
5
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
+
Reference:
|
|
7
|
+
[1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
|
|
8
|
+
[2] Ma W, Li P, Chen C, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
|
|
9
|
+
|
|
10
|
+
Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
|
|
11
|
+
Meta’s Generative Recommenders. It replaces softmax attention with lightweight
|
|
12
|
+
pointwise activations, enabling extremely deep stacks on long behavior sequences.
|
|
13
|
+
|
|
14
|
+
In each HSTU layer:
|
|
15
|
+
(1) Tokens are projected into four streams U, V, Q, K via a shared feed-forward block
|
|
16
|
+
(2) Softmax-free interactions combine QK^T with Relative Attention Bias (RAB) to encode distance
|
|
17
|
+
(3) Aggregated context is modulated by U-gating and mapped back through an output projection
|
|
18
|
+
|
|
19
|
+
Stacking layers yields an efficient causal encoder for next-item
|
|
20
|
+
generation. With a tied-embedding LM head, HSTU forms
|
|
21
|
+
a full generative recommendation model.
|
|
22
|
+
|
|
23
|
+
Key Advantages:
|
|
24
|
+
- Softmax-free attention scales better on deep/long sequences
|
|
25
|
+
- RAB captures temporal structure without extra attention heads
|
|
26
|
+
- Causal masking and padding-aware normalization fit real logs
|
|
27
|
+
- Weight tying reduces parameters and stabilizes training
|
|
28
|
+
- Serves as a drop-in backbone for generative recommendation
|
|
29
|
+
|
|
30
|
+
HSTU(层次化序列转导单元)是 Meta 生成式推荐的核心编码器,
|
|
31
|
+
用点式激活替代 softmax 注意力,可在长序列上轻松堆叠深层结构。
|
|
32
|
+
|
|
33
|
+
单层 HSTU 的主要步骤:
|
|
34
|
+
(1) 将输入一次性映射到 U、V、Q、K 四条通路
|
|
35
|
+
(2) 利用不含 softmax 的 QK^T 结合相对位置偏置(RAB)建模距离信息
|
|
36
|
+
(3) 用 U 对聚合上下文进行门控,再映射回输出空间
|
|
37
|
+
|
|
38
|
+
多层堆叠后,可得到高效的因果编码器;与绑权 LM 头配合即可完成 next-item 预测。
|
|
39
|
+
|
|
40
|
+
主要优势:
|
|
41
|
+
- 摆脱 softmax,在长序列、深层模型上更易扩展
|
|
42
|
+
- 相对位置偏置稳健刻画时序结构
|
|
43
|
+
- 因果 mask 与 padding 感知归一化贴合真实日志
|
|
44
|
+
- 绑权输出头降低参数量并提升训练稳定性
|
|
45
|
+
- 直接作为生成式推荐的骨干网络
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
from __future__ import annotations
|
|
49
|
+
|
|
50
|
+
import math
|
|
51
|
+
from typing import Optional
|
|
52
|
+
|
|
53
|
+
import torch
|
|
54
|
+
import torch.nn as nn
|
|
55
|
+
import torch.nn.functional as F
|
|
56
|
+
|
|
57
|
+
from nextrec.basic.model import BaseModel
|
|
58
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _relative_position_bucket(
|
|
62
|
+
relative_position: torch.Tensor,
|
|
63
|
+
num_buckets: int = 32,
|
|
64
|
+
max_distance: int = 128,
|
|
65
|
+
) -> torch.Tensor:
|
|
66
|
+
"""
|
|
67
|
+
map the relative position (i-j) to a bucket in [0, num_buckets).
|
|
68
|
+
"""
|
|
69
|
+
# only need the negative part for causal attention
|
|
70
|
+
n = -relative_position
|
|
71
|
+
n = torch.clamp(n, min=0)
|
|
72
|
+
|
|
73
|
+
# when the distance is small, keep it exact
|
|
74
|
+
max_exact = num_buckets // 2
|
|
75
|
+
is_small = n < max_exact
|
|
76
|
+
|
|
77
|
+
# when the distance is too far, do log scaling
|
|
78
|
+
large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
|
|
79
|
+
large_val = torch.clamp(large_val, max=num_buckets - 1)
|
|
80
|
+
|
|
81
|
+
buckets = torch.where(is_small, n.long(), large_val)
|
|
82
|
+
return buckets
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RelativePositionBias(nn.Module):
    """
    Learned relative attention bias (RAB) for HSTU attention.

    Every (query i, key j) pair is assigned a bucket from its offset j - i
    (see ``_relative_position_bucket``); each bucket owns one learned scalar
    per head.

    Args:
        num_heads: number of attention heads H (one bias value per head).
        num_buckets: number of distance buckets.
        max_distance: distance at which the logarithmic bucketing saturates.
    """

    def __init__(
        self,
        num_heads: int,
        num_buckets: int = 32,
        max_distance: int = 128,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # One row per bucket, one column per head.
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Build the bias tensor for a sequence of length ``seq_len``.

        Returns:
            Tensor of shape [1, num_heads, seq_len, seq_len].
        """
        positions = torch.arange(seq_len, device=device)
        # offsets[i, j] = j - i (key position minus query position), shape [T, T]
        offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
        bucket_ids = _relative_position_bucket(
            offsets,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        per_pair = self.embedding(bucket_ids)  # [T, T, H]
        return per_pair.permute(2, 0, 1)[None]  # [1, H, T, T]
|
|
111
|
+
|
|
112
|
+
class HSTUPointwiseAttention(nn.Module):
    """
    Pointwise aggregation attention that implements HSTU without softmax:
      1) [U, V, Q, K] = split( φ1(f1(X)) ), U: gate, V: value, Q: query, K: key
      2) AV = φ2(alpha * QK^T + rab) V / N, where N is the number of keys each
         query is allowed to attend to (clamped to at least 1)
      3) Y = f2( Dropout( Norm(AV) ⊙ U ) )
    φ1, φ2 use SiLU; Norm uses LayerNorm.

    Masked (query, key) pairs are removed by multiplying φ2's output by a 0/1
    mask rather than by adding -inf to the logits: unlike exp(-inf) in
    softmax, SiLU(-inf) evaluates to NaN in PyTorch and would poison every
    output row.

    Args:
        d_model: model width D; must be divisible by ``num_heads``.
        num_heads: number of heads H.
        dropout: dropout applied to the gated output before ``out_proj``.
        alpha: scale applied to QK^T; defaults to 1 / sqrt(D / H).

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        alpha: Optional[float] = None,
    ):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
        # project input to 4 * d_model for U, V, Q, K
        self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
        # project output back to d_model
        self.out_proj = nn.Linear(d_model, d_model, bias=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        [B, T, D] -> [B, H, T, d_head]
        """
        B, T, D = x.shape
        return x.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        rab: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input of shape [B, T, D].
            attn_mask: [T, T] mask; an entry equal to 0 allows the pair, any
                other value blocks it. Works with an additive float mask
                (0 / -inf) or a boolean mask (True = blocked).
            key_padding_mask: [B, T] boolean, True marks padding keys.
            rab: relative attention bias broadcastable to [B, H, T, T].

        Returns:
            Tensor of shape [B, T, D].
        """
        B, T, D = x.shape

        # Eq.(1): [U, V, Q, K] = split( φ1(f1(X)) )
        h = F.silu(self.in_proj(x))  # [B, T, 4D]
        U, V, Q, K = h.chunk(4, dim=-1)  # each [B, T, D]

        Qh = self._reshape_heads(Q)  # [B, H, T, d_head]
        Kh = self._reshape_heads(K)  # [B, H, T, d_head]
        Vh = self._reshape_heads(V)  # [B, H, T, d_head]

        # attention logits: alpha * QK^T (alpha defaults to 1/sqrt(d_head)); no softmax
        logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha  # [B, H, T, T]

        # add relative position bias (rab^p); rab^t could be added the same way
        if rab is not None:
            logits = logits + rab

        # allowed[b, h, i, j] = 1 if query i may attend to key j, else 0.
        # Built in the logits dtype so the later matmul never mixes dtypes.
        allowed = torch.ones_like(logits)  # [B, H, T, T]
        if attn_mask is not None:
            allowed = allowed * (attn_mask.view(1, 1, T, T) == 0).to(logits.dtype)
        if key_padding_mask is not None:
            valid_keys = (~key_padding_mask).to(logits.dtype).view(B, 1, 1, T)  # 1 = real token
            allowed = allowed * valid_keys

        # Eq.(2): A(X)V(X) = φ2(QK^T + rab) V(X) / N
        # Mask AFTER the activation: logits stay finite, masked pairs become exactly 0.
        attn = F.silu(logits) * allowed  # [B, H, T, T]
        denom = allowed.sum(dim=-1, keepdim=True).clamp(min=1.0)  # [B, H, T, 1]
        attn = attn / denom

        AV = torch.matmul(attn, Vh)  # [B, H, T, d_head]
        AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # back to [B, T, D]

        # Eq.(3): gate the normalized context with U (already [B, T, D]) and project out
        y = self.out_proj(self.dropout(self.norm(AV) * U))  # [B, T, D]
        return y
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class HSTULayer(nn.Module):
    """
    One HSTU block: pointwise attention wrapped in a residual connection.

    When ``use_rab_pos`` is enabled the layer owns its own learned relative
    position bias, which is recomputed for the current sequence length on
    every forward pass and handed to the attention module.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        use_rab_pos: bool = True,
        rab_num_buckets: int = 32,
        rab_max_distance: int = 128,
    ):
        super().__init__()
        self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.use_rab_pos = use_rab_pos
        if use_rab_pos:
            self.rel_pos_bias = RelativePositionBias(
                num_heads=num_heads,
                num_buckets=rab_num_buckets,
                max_distance=rab_max_distance,
            )
        else:
            self.rel_pos_bias = None

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input of shape [B, T, D].
            attn_mask: [T, T] attention mask forwarded to the attention module.
            key_padding_mask: [B, T] padding mask forwarded to the attention module.

        Returns:
            ``x`` plus the dropped-out attention output, shape [B, T, D].
        """
        _, seq_len, _ = x.shape
        bias = self.rel_pos_bias(seq_len=seq_len, device=x.device) if self.use_rab_pos else None  # [1, H, T, T]
        residual_update = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=bias)
        return x + self.dropout(residual_update)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class HSTU(BaseModel):
    """
    HSTU encoder for next-item prediction in a causal, generative setup.

    Pipeline:
      1) Embed tokens + learned absolute positions from the behavior history
      2) Apply stacked HSTU layers with a causal mask, a padding mask and optional RAB
      3) Take the hidden state at the last non-padding position and produce
         next-item logits via the (optionally tied) LM head

    Only the first SequenceFeature is used as the behavior history. Tokens
    equal to its ``padding_idx`` (0 if unset) are treated as padding; the last
    real token is located explicitly, so both left- and right-padded histories
    work.
    """

    @property
    def model_name(self) -> str:
        return "HSTU"

    @property
    def task_type(self) -> str:
        return "multiclass"

    def __init__(
        self,
        sequence_features: list[SequenceFeature],
        dense_features: Optional[list[DenseFeature]] = None,
        sparse_features: Optional[list[SparseFeature]] = None,
        d_model: Optional[int] = None,
        num_heads: int = 8,
        num_layers: int = 4,
        max_seq_len: int = 200,
        dropout: float = 0.1,
        # RAB settings
        use_rab_pos: bool = True,
        rab_num_buckets: int = 32,
        rab_max_distance: int = 128,
        tie_embeddings: bool = True,
        target: Optional[list[str] | str] = None,
        optimizer: str = "adam",
        optimizer_params: Optional[dict] = None,
        scheduler: Optional[str] = None,
        scheduler_params: Optional[dict] = None,
        loss_params: Optional[dict] = None,
        embedding_l1_reg: float = 0.0,
        dense_l1_reg: float = 0.0,
        embedding_l2_reg: float = 0.0,
        dense_l2_reg: float = 0.0,
        device: str = "cpu",
        **kwargs,
    ):
        if not sequence_features:
            raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")

        # demo version: use the first SequenceFeature as the main sequence
        self.history_feature = sequence_features[0]

        # d_model if given, else the history embedding size (at least 32)
        hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
        # Round hidden_dim up so it is divisible by num_heads
        if hidden_dim % num_heads != 0:
            hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)

        self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
        self.vocab_size = self.history_feature.vocab_size
        self.max_seq_len = max_seq_len

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=self.task_type,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            **kwargs,
        )

        # token & position embedding (paper usually includes pos embedding / RAB in encoder)
        self.token_embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=hidden_dim,
            padding_idx=self.padding_idx,
        )
        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
        self.input_dropout = nn.Dropout(dropout)

        # HSTU layers
        self.layers = nn.ModuleList(
            [
                HSTULayer(
                    d_model=hidden_dim,
                    num_heads=num_heads,
                    dropout=dropout,
                    use_rab_pos=use_rab_pos,
                    rab_num_buckets=rab_num_buckets,
                    rab_max_distance=rab_max_distance,
                )
                for _ in range(num_layers)
            ]
        )

        self.final_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
        if tie_embeddings:
            self.lm_head.weight = self.token_embedding.weight

        # causal mask buffer (lazily built, non-persistent so it is not saved in state_dict)
        self.register_buffer("causal_mask", torch.empty(0), persistent=False)
        # padding_idx is always resolved above (defaults to 0); padded targets are ignored by the loss
        self.ignore_index = self.padding_idx

        optimizer_params = optimizer_params or {}
        scheduler_params = scheduler_params or {}
        # Copy so setdefault below never mutates the caller's dict.
        loss_params = dict(loss_params or {})
        loss_params.setdefault("ignore_index", self.ignore_index)

        self.compile(
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            scheduler=scheduler,
            scheduler_params=scheduler_params,
            loss="crossentropy",
            loss_params=loss_params,
        )
        self._register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])

    def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Build (or reuse) a causal mask of shape [T, T]: upper triangle is -inf, others are 0.

        The mask is cached in the ``causal_mask`` buffer. It is rebuilt when the
        cache is empty, too small, or on a different device than requested.
        """
        cached = self.causal_mask
        if cached.numel() == 0 or cached.size(0) < seq_len or cached.device != torch.device(device):
            mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
            self.causal_mask = torch.triu(mask, diagonal=1)
        return self.causal_mask[:seq_len, :seq_len]

    def _trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
        """
        Keep at most the last ``max_seq_len`` positions of ``seq`` ([B, T_raw]).

        NOTE(review): keeping the tail assumes the most recent items sit at the
        end; right-padded inputs longer than max_seq_len would lose real items.
        """
        if seq.size(1) <= self.max_seq_len:
            return seq
        return seq[:, -self.max_seq_len :]

    @staticmethod
    def _last_valid_index(padding_mask: torch.Tensor) -> torch.Tensor:
        """
        Return, for each row of a [B, T] padding mask (True = pad), the index of
        the last non-padding position; rows that are entirely padding map to 0.
        """
        positions = torch.arange(padding_mask.size(1), device=padding_mask.device).expand_as(padding_mask)
        # Padding positions get -1 so max() picks the right-most real token.
        last = torch.where(~padding_mask, positions, torch.full_like(positions, -1)).max(dim=1).values
        return last.clamp(min=0)

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            x: input dict; ``x[history_feature.name]`` holds item ids of shape [B, T_raw].

        Returns:
            Next-item logits of shape [B, vocab_size].
        """
        seq = x[self.history_feature.name].long()  # [B, T_raw]
        seq = self._trim_sequence(seq)  # [B, T]

        B, T = seq.shape
        device = seq.device
        # position ids: [B, T]
        pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
        token_emb = self.token_embedding(seq)  # [B, T, D]
        pos_emb = self.position_embedding(pos_ids)  # [B, T, D]
        hidden_states = self.input_dropout(token_emb + pos_emb)

        # padding mask: True = pad
        padding_mask = seq.eq(self.padding_idx)  # [B, T]
        attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]

        for layer in self.layers:
            hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
        hidden_states = self.final_norm(hidden_states)  # [B, T, D]

        # Gather the last real token per row (works for left or right padding).
        last_index = self._last_valid_index(padding_mask)  # [B]
        last_hidden = hidden_states[torch.arange(B, device=device), last_index]  # [B, D]

        logits = self.lm_head(last_hidden)  # [B, vocab_size]
        return logits

    def compute_loss(self, y_pred, y_true):
        """
        Cross-entropy between next-item logits and the target item id.

        Args:
            y_pred: logits of shape [B, vocab_size].
            y_true: [B] or [B, 1], the id of the next item.

        Raises:
            ValueError: if ``y_true`` is None.
        """
        if y_true is None:
            raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
        labels = y_true.view(-1).long()
        return self.loss_fn[0](y_pred, labels)
|
nextrec/models/match/mind.py
CHANGED
|
@@ -13,7 +13,116 @@ from typing import Literal
|
|
|
13
13
|
|
|
14
14
|
from nextrec.basic.model import BaseMatchModel
|
|
15
15
|
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
16
|
-
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
16
|
+
from nextrec.basic.layers import MLP, EmbeddingLayer
|
|
17
|
+
|
|
18
|
+
class MultiInterestSA(nn.Module):
    """
    Self-attentive multi-interest extractor (the ComiRec-SA extractor of
    Cen et al., 2020, built on the structured self-attention of Lin et al., 2017).

    For item embeddings ``seq_emb`` of shape [B, S, E] it computes attention
    ``A = softmax_over_S(tanh(seq_emb @ W1) @ W2)`` of shape [B, S, K] and
    returns the K interest vectors ``A^T @ seq_emb`` of shape [B, K, E].

    Args:
        embedding_dim: item embedding size E.
        interest_num: number of interests K.
        hidden_dim: width of the attention projection; defaults to 4 * embedding_dim.
    """

    def __init__(self, embedding_dim, interest_num, hidden_dim=None):
        super(MultiInterestSA, self).__init__()
        self.embedding_dim = embedding_dim
        self.interest_num = interest_num
        # Honour an explicit hidden_dim; fall back to 4 * embedding_dim otherwise.
        self.hidden_dim = hidden_dim if hidden_dim is not None else self.embedding_dim * 4
        self.W1 = nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim))
        self.W2 = nn.Parameter(torch.rand(self.hidden_dim, self.interest_num))
        # W3 is not used in forward(); kept so existing checkpoints still load.
        self.W3 = nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim))

    def forward(self, seq_emb, mask=None):
        """
        Args:
            seq_emb: item embeddings of shape [B, S, E].
            mask: optional validity mask of shape [B, S] or [B, S, 1];
                entries equal to 0 mark padding and receive no attention.

        Returns:
            Interest embeddings of shape [B, K, E].
        """
        H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
        scores = torch.einsum('bsd, dk -> bsk', H, self.W2)  # [B, S, K]
        if mask is not None:
            if mask.dim() == 2:
                # [B, S] -> [B, S, 1] so it broadcasts over the K interests
                mask = mask.unsqueeze(-1)
            scores = scores + -1.e9 * (1 - mask.float())
        # Normalize over the sequence axis, then aggregate items per interest.
        A = F.softmax(scores, dim=1).permute(0, 2, 1)  # [B, K, S]
        multi_interest_emb = torch.matmul(A, seq_emb)  # [B, K, E]
        return multi_interest_emb
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class CapsuleNetwork(nn.Module):
    """
    Dynamic-routing capsule network used in MIND (Li et al., 2019).

    Args:
        embedding_dim: item embedding size E (also the interest capsule size).
        seq_len: behavior sequence length S; inputs must have exactly S positions.
        bilinear_type: 0 = one E->E linear map shared by all interests, with
            randomly initialised routing logits (MIND); 1 = one E->(K*E) linear
            map; any other value = a separate (K*E) x E map per sequence
            position. Types 1 and 2 start the routing logits at zero.
        interest_num: number of interest capsules K.
        routing_times: number of routing iterations (must be >= 1).
        relu_layer: if True, pass the output capsules through Linear(E, E) + ReLU.

    Raises:
        ValueError: if ``routing_times`` < 1.
    """

    def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
        super(CapsuleNetwork, self).__init__()
        if routing_times < 1:
            raise ValueError(f"[CapsuleNetwork Error] routing_times must be >= 1, got {routing_times}")
        self.embedding_dim = embedding_dim  # h
        self.seq_len = seq_len  # s
        self.bilinear_type = bilinear_type
        self.interest_num = interest_num
        self.routing_times = routing_times

        self.relu_layer = relu_layer
        # Routing-logit updates use detached capsules; only the last iteration carries gradients.
        self.stop_grad = True
        self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
        if self.bilinear_type == 0:  # MIND
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
        elif self.bilinear_type == 1:
            self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
        else:
            self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
            nn.init.xavier_uniform_(self.w)

    @staticmethod
    def _squash(capsule):
        """Capsule squash: rescale each vector s by ||s||^2 / (1 + ||s||^2) / ||s|| (1e-9 for stability)."""
        cap_norm = torch.sum(torch.square(capsule), -1, True)
        scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
        return scalar_factor * capsule

    def forward(self, item_eb, mask):
        """
        Args:
            item_eb: behavior embeddings of shape [B, S, E] with S == seq_len.
            mask: [B, S] validity mask; positions equal to 0 get zero routing weight.

        Returns:
            Interest capsules of shape [B, K, E].
        """
        if self.bilinear_type == 0:
            item_eb_hat = self.linear(item_eb)
            item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
        elif self.bilinear_type == 1:
            item_eb_hat = self.linear(item_eb)
        else:
            u = torch.unsqueeze(item_eb, dim=2)
            item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)

        # [B, S, K*E] -> [B, K, S, E]
        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
        item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
        item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))

        item_eb_hat_iter = item_eb_hat.detach() if self.stop_grad else item_eb_hat

        batch_size = item_eb_hat.shape[0]
        if self.bilinear_type > 0:
            capsule_weight = torch.zeros(batch_size, self.interest_num, self.seq_len,
                                         device=item_eb.device, requires_grad=False)
        else:
            capsule_weight = torch.randn(batch_size, self.interest_num, self.seq_len,
                                         device=item_eb.device, requires_grad=False)

        # The mask is the same for every routing iteration, so build it once.
        atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)  # [B, K, S]
        paddings = torch.zeros_like(atten_mask, dtype=torch.float)

        for i in range(self.routing_times):  # dynamic routing iterations
            capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
            capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
            capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)  # [B, K, 1, S]

            if i < self.routing_times - 1:
                # Intermediate iteration: update routing logits from detached capsules.
                interest_capsule = self._squash(torch.matmul(capsule_softmax_weight, item_eb_hat_iter))
                delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
                delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
                capsule_weight = capsule_weight + delta_weight
            else:
                # Final iteration: build the output from the non-detached projections.
                interest_capsule = self._squash(torch.matmul(capsule_softmax_weight, item_eb_hat))

        interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))

        if self.relu_layer:
            interest_capsule = self.relu(interest_capsule)

        return interest_capsule
|
|
17
126
|
|
|
18
127
|
|
|
19
128
|
class MIND(BaseMatchModel):
|
|
@@ -1,7 +1,44 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Date: create on 09/11/2025
|
|
3
|
+
Checkpoint: edit on 29/11/2025
|
|
3
4
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
4
|
-
Reference:
|
|
5
|
+
Reference:
|
|
6
|
+
[1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach
|
|
7
|
+
for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
|
|
8
|
+
(https://dl.acm.org/doi/10.1145/3209978.3210007)
|
|
9
|
+
|
|
10
|
+
Entire Space Multi-task Model (ESMM) targets CVR estimation by jointly optimizing
|
|
11
|
+
CTR and CTCVR on the full impression space, mitigating sample selection bias and
|
|
12
|
+
conversion sparsity. CTR predicts P(click | impression), CVR predicts P(conversion |
|
|
13
|
+
click), and their product forms CTCVR supervised on impression labels.
|
|
14
|
+
|
|
15
|
+
Workflow:
|
|
16
|
+
(1) Shared embeddings encode all features from impressions
|
|
17
|
+
(2) CTR tower outputs click probability conditioned on impression
|
|
18
|
+
(3) CVR tower outputs conversion probability conditioned on click
|
|
19
|
+
(4) CTCVR = CTR * CVR enables end-to-end training without filtering clicked data
|
|
20
|
+
|
|
21
|
+
Key Advantages:
|
|
22
|
+
- Trains on the entire impression space to remove selection bias
|
|
23
|
+
- Transfers rich click signals to sparse conversion prediction via shared embeddings
|
|
24
|
+
- Stable optimization by decomposing CTCVR into well-defined sub-tasks
|
|
25
|
+
- Simple architecture that can pair with other multi-task variants
|
|
26
|
+
|
|
27
|
+
ESMM(Entire Space Multi-task Model)用于 CVR 预估,通过在曝光全空间联合训练
|
|
28
|
+
CTR 与 CTCVR,缓解样本选择偏差和转化数据稀疏问题。CTR 预测 P(click|impression),
|
|
29
|
+
CVR 预测 P(conversion|click),二者相乘得到 CTCVR 并在曝光标签上直接监督。
|
|
30
|
+
|
|
31
|
+
流程:
|
|
32
|
+
(1) 共享 embedding 统一处理曝光特征
|
|
33
|
+
(2) CTR 塔输出曝光下的点击概率
|
|
34
|
+
(3) CVR 塔输出点击后的转化概率
|
|
35
|
+
(4) CTR 与 CVR 相乘得到 CTCVR,无需只在点击子集上训练
|
|
36
|
+
|
|
37
|
+
主要优点:
|
|
38
|
+
- 在曝光空间训练,避免样本选择偏差
|
|
39
|
+
- 通过共享表示将点击信号迁移到稀疏的转化任务
|
|
40
|
+
- 将 CTCVR 分解为子任务,优化稳定
|
|
41
|
+
- 结构简单,可与其它多任务方法组合使用
|
|
5
42
|
"""
|
|
6
43
|
|
|
7
44
|
import torch
|
|
@@ -77,37 +114,22 @@ class ESMM(BaseModel):
|
|
|
77
114
|
|
|
78
115
|
# All features
|
|
79
116
|
self.all_features = dense_features + sparse_features + sequence_features
|
|
80
|
-
|
|
81
117
|
# Shared embedding layer
|
|
82
118
|
self.embedding = EmbeddingLayer(features=self.all_features)
|
|
119
|
+
input_dim = self.embedding.input_dim # Calculate input dimension, better way than below
|
|
120
|
+
# emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
|
|
121
|
+
# dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
|
|
122
|
+
# input_dim = emb_dim_total + dense_input_dim
|
|
83
123
|
|
|
84
|
-
# Calculate input dimension
|
|
85
|
-
emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
|
|
86
|
-
dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
|
|
87
|
-
input_dim = emb_dim_total + dense_input_dim
|
|
88
|
-
|
|
89
124
|
# CTR tower
|
|
90
125
|
self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
|
|
91
126
|
|
|
92
127
|
# CVR tower
|
|
93
128
|
self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
|
|
94
|
-
self.prediction_layer = PredictionLayer(
|
|
95
|
-
task_type=self.task_type,
|
|
96
|
-
task_dims=[1, 1]
|
|
97
|
-
)
|
|
98
|
-
|
|
129
|
+
self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
|
|
99
130
|
# Register regularization weights
|
|
100
|
-
self._register_regularization_weights(
|
|
101
|
-
|
|
102
|
-
include_modules=['ctr_tower', 'cvr_tower']
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
self.compile(
|
|
106
|
-
optimizer=optimizer,
|
|
107
|
-
optimizer_params=optimizer_params,
|
|
108
|
-
loss=loss,
|
|
109
|
-
loss_params=loss_params,
|
|
110
|
-
)
|
|
131
|
+
self._register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
|
|
132
|
+
self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
|
|
111
133
|
|
|
112
134
|
def forward(self, x):
|
|
113
135
|
# Get all embeddings and flatten
|
|
@@ -119,11 +141,8 @@ class ESMM(BaseModel):
|
|
|
119
141
|
logits = torch.cat([ctr_logit, cvr_logit], dim=1)
|
|
120
142
|
preds = self.prediction_layer(logits)
|
|
121
143
|
ctr, cvr = preds.chunk(2, dim=1)
|
|
122
|
-
|
|
123
|
-
# CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
|
|
124
144
|
ctcvr = ctr * cvr # [B, 1]
|
|
125
145
|
|
|
126
|
-
# Output: [CTR, CTCVR]
|
|
127
|
-
# Note: We supervise CTR with click labels and CTCVR with conversion labels
|
|
146
|
+
# Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
|
|
128
147
|
y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
|
|
129
148
|
return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
|