nextrec 0.2.7__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +4 -8
  3. nextrec/basic/callback.py +1 -1
  4. nextrec/basic/features.py +33 -25
  5. nextrec/basic/layers.py +164 -601
  6. nextrec/basic/loggers.py +4 -5
  7. nextrec/basic/metrics.py +39 -115
  8. nextrec/basic/model.py +257 -177
  9. nextrec/basic/session.py +1 -5
  10. nextrec/data/__init__.py +12 -0
  11. nextrec/data/data_utils.py +3 -27
  12. nextrec/data/dataloader.py +26 -34
  13. nextrec/data/preprocessor.py +2 -1
  14. nextrec/loss/listwise.py +6 -4
  15. nextrec/loss/loss_utils.py +10 -6
  16. nextrec/loss/pairwise.py +5 -3
  17. nextrec/loss/pointwise.py +7 -13
  18. nextrec/models/generative/__init__.py +5 -0
  19. nextrec/models/generative/hstu.py +399 -0
  20. nextrec/models/match/mind.py +110 -1
  21. nextrec/models/multi_task/esmm.py +46 -27
  22. nextrec/models/multi_task/mmoe.py +48 -30
  23. nextrec/models/multi_task/ple.py +156 -141
  24. nextrec/models/multi_task/poso.py +413 -0
  25. nextrec/models/multi_task/share_bottom.py +43 -26
  26. nextrec/models/ranking/__init__.py +2 -0
  27. nextrec/models/ranking/dcn.py +20 -1
  28. nextrec/models/ranking/dcn_v2.py +84 -0
  29. nextrec/models/ranking/deepfm.py +44 -18
  30. nextrec/models/ranking/dien.py +130 -27
  31. nextrec/models/ranking/masknet.py +13 -67
  32. nextrec/models/ranking/widedeep.py +39 -18
  33. nextrec/models/ranking/xdeepfm.py +34 -1
  34. nextrec/utils/common.py +26 -1
  35. nextrec/utils/optimizer.py +7 -3
  36. nextrec-0.3.2.dist-info/METADATA +312 -0
  37. nextrec-0.3.2.dist-info/RECORD +57 -0
  38. nextrec-0.2.7.dist-info/METADATA +0 -281
  39. nextrec-0.2.7.dist-info/RECORD +0 -54
  40. {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/WHEEL +0 -0
  41. {nextrec-0.2.7.dist-info → nextrec-0.3.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/generative/hstu.py
@@ -0,0 +1,399 @@
1
+ """
2
+ [Info: this version is not released yet; further research on the source code and paper is needed]
3
+ Date: create on 01/12/2025
4
+ Checkpoint: edit on 01/12/2025
5
+ Author: Yang Zhou, zyaztec@gmail.com
6
+ Reference:
7
+ [1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
8
+ [2] Zhai J, Liao L, Liu X, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
9
+
10
+ Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
11
+ Meta’s Generative Recommenders. It replaces softmax attention with lightweight
12
+ pointwise activations, enabling extremely deep stacks on long behavior sequences.
13
+
14
+ In each HSTU layer:
15
+ (1) Tokens are projected into four streams U, V, Q, K via a shared feed-forward block
16
+ (2) Softmax-free interactions combine QK^T with Relative Attention Bias (RAB) to encode distance
17
+ (3) Aggregated context is modulated by U-gating and mapped back through an output projection
18
+
19
+ Stacking layers yields an efficient causal encoder for next-item
20
+ generation. With a tied-embedding LM head, HSTU forms
21
+ a full generative recommendation model.
22
+
23
+ Key Advantages:
24
+ - Softmax-free attention scales better on deep/long sequences
25
+ - RAB captures temporal structure without extra attention heads
26
+ - Causal masking and padding-aware normalization fit real logs
27
+ - Weight tying reduces parameters and stabilizes training
28
+ - Serves as a drop-in backbone for generative recommendation
29
+
30
+ HSTU (Hierarchical Sequential Transduction Unit) is the core encoder of Meta's Generative Recommenders;
31
+ it replaces softmax attention with pointwise activations, so deep stacks remain practical on long sequences.
32
+
33
+ Main steps in a single HSTU layer:
34
+ (1) Project the input in one pass into the four streams U, V, Q, K
35
+ (2) Model distance with softmax-free QK^T combined with the relative position bias (RAB)
36
+ (3) Gate the aggregated context with U, then map it back to the output space
37
+
38
+ Stacking layers yields an efficient causal encoder; combined with a tied-weight LM head it performs next-item prediction.
39
+
40
+ Main advantages:
41
+ - Dropping softmax scales more easily to long sequences and deep models
42
+ - Relative position bias robustly captures temporal structure
43
+ - Causal masking and padding-aware normalization fit real logs
44
+ - A tied output head reduces parameters and improves training stability
45
+ - Serves directly as the backbone of a generative recommender
46
+ """
47
+
48
+ from __future__ import annotations
49
+
50
+ import math
51
+ from typing import Optional
52
+
53
+ import torch
54
+ import torch.nn as nn
55
+ import torch.nn.functional as F
56
+
57
+ from nextrec.basic.model import BaseModel
58
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
59
+
60
+
61
+ def _relative_position_bucket(
62
+ relative_position: torch.Tensor,
63
+ num_buckets: int = 32,
64
+ max_distance: int = 128,
65
+ ) -> torch.Tensor:
66
+ """
67
+ map the relative position (i-j) to a bucket in [0, num_buckets).
68
+ """
69
+ # only need the negative part for causal attention
70
+ n = -relative_position
71
+ n = torch.clamp(n, min=0)
72
+
73
+ # when the distance is small, keep it exact
74
+ max_exact = num_buckets // 2
75
+ is_small = n < max_exact
76
+
77
+ # when the distance is too far, do log scaling
78
+ large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
79
+ large_val = torch.clamp(large_val, max=num_buckets - 1)
80
+
81
+ buckets = torch.where(is_small, n.long(), large_val)
82
+ return buckets
83
+
84
+
85
+ class RelativePositionBias(nn.Module):
86
+ """
87
+ Compute relative position bias (RAB) for HSTU attention.
88
+ The input is the sequence length T, output is [1, num_heads, seq_len, seq_len].
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ num_heads: int,
94
+ num_buckets: int = 32,
95
+ max_distance: int = 128,
96
+ ):
97
+ super().__init__()
98
+ self.num_heads = num_heads
99
+ self.num_buckets = num_buckets
100
+ self.max_distance = max_distance
101
+ self.embedding = nn.Embedding(num_buckets, num_heads)
102
+
103
+ def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
104
+ # positions: [T]
105
+ ctx = torch.arange(seq_len, device=device)[:, None]
106
+ mem = torch.arange(seq_len, device=device)[None, :]
107
+ rel_pos = mem - ctx # a matrix to describe all relative positions for each [i,j] pair, shape = [seq_len, seq_len]
108
+ buckets = _relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance,) # map to buckets
109
+ values = self.embedding(buckets) # embedding vector for each [i,j] pair, shape = [seq_len, seq_len, embedding_dim=num_heads]
110
+ return values.permute(2, 0, 1).unsqueeze(0) # [1, num_heads, seq_len, seq_len]
111
+
112
+ class HSTUPointwiseAttention(nn.Module):
113
+ """
114
+ Pointwise aggregation attention that implements HSTU without softmax:
115
+ 1) [U, V, Q, K] = split( φ1(f1(X)) ), U: gate, V: value, Q: query, K: key
116
+ 2) AV = φ2(QK^T + rab) V / N, AV is the attention-weighted value
117
+ 3) Y = f2( Norm(AV) ⊙ U ), Y is the output
118
+ φ1, φ2 use SiLU; Norm uses LayerNorm.
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ d_model: int,
124
+ num_heads: int,
125
+ dropout: float = 0.1,
126
+ alpha: float | None = None
127
+ ):
128
+ super().__init__()
129
+ if d_model % num_heads != 0:
130
+ raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")
131
+
132
+ self.d_model = d_model
133
+ self.num_heads = num_heads
134
+ self.d_head = d_model // num_heads
135
+ self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
136
+ # project input to 4 * d_model for U, V, Q, K
137
+ self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
138
+ # project output back to d_model
139
+ self.out_proj = nn.Linear(d_model, d_model, bias=True)
140
+ self.dropout = nn.Dropout(dropout)
141
+ self.norm = nn.LayerNorm(d_model)
142
+
143
+ def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
144
+ """
145
+ [B, T, D] -> [B, H, T, d_head]
146
+ """
147
+ B, T, D = x.shape
148
+ return x.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
149
+
150
+ def forward(
151
+ self,
152
+ x: torch.Tensor,
153
+ attn_mask: Optional[torch.Tensor] = None, # [T, T] with 0 or -inf
154
+ key_padding_mask: Optional[torch.Tensor] = None, # [B, T], True = pad
155
+ rab: Optional[torch.Tensor] = None, # [1, H, T, T] or None
156
+ ) -> torch.Tensor:
157
+ B, T, D = x.shape
158
+
159
+ # Eq.(1): [U, V, Q, K] = split( φ1(f1(X)) )
160
+ h = F.silu(self.in_proj(x)) # [B, T, 4D]
161
+ U, V, Q, K = h.chunk(4, dim=-1) # each [B, T, D]
162
+
163
+ Qh = self._reshape_heads(Q) # [B, H, T, d_head]
164
+ Kh = self._reshape_heads(K) # [B, H, T, d_head]
165
+ Vh = self._reshape_heads(V) # [B, H, T, d_head]
166
+ Uh = self._reshape_heads(U) # [B, H, T, d_head]
167
+
168
+ # attention logits: QK^T scaled by alpha (default 1/sqrt(d_head)); no softmax is applied
169
+ logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha # [B, H, T, T]
170
+
171
+ # add relative position bias (rab^p); a temporal bias rab^t could be added here in the future
172
+ if rab is not None:
173
+ # rab: [1, H, T, T] or [B, H, T, T]
174
+ logits = logits + rab
175
+
176
+ # construct an "allowed" mask to calculate N
177
+ # 1 indicates that the (query i, key j) pair is a valid attention pair; 0 indicates it is masked out
178
+ allowed = torch.ones_like(logits, dtype=torch.float) # [B, H, T, T]
179
+
180
+ # causal mask: attn_mask is usually an upper triangular matrix of -inf with shape [T, T]
181
+ if attn_mask is not None:
182
+ # fold the causal constraint into the "allowed" mask (1 = may attend, 0 = masked) instead of
183
+ # adding -inf to the logits: SiLU(-inf) evaluates to NaN, so masking is applied multiplicatively below
184
+ allowed = allowed * (attn_mask.view(1, 1, T, T) == 0).float()
185
+
186
+ # padding mask: key_padding_mask is usually [B, T], True = pad
187
+ if key_padding_mask is not None:
188
+ # valid: 1 for non-pad, 0 for pad
189
+ valid = (~key_padding_mask).float() # [B, T]
190
+ valid = valid.view(B, 1, 1, T) # [B, 1, 1, T]
191
+ allowed = allowed * valid
192
+
193
+ # Eq.(2): A(X)V(X) = φ2(QK^T + rab) V(X) / N
194
+ attn = F.silu(logits) * allowed # [B, H, T, T], zero out future and padded positions
195
+ denom = allowed.sum(dim=-1, keepdim=True) # [B, H, T, 1]
196
+ denom = denom.clamp(min=1.0)
197
+
198
+ attn = attn / denom # [B, H, T, T]
199
+ AV = torch.matmul(attn, Vh) # [B, H, T, d_head]
200
+ AV = AV.transpose(1, 2).contiguous().view(B, T, D) # reshape back to [B, T, D]
201
+ U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
202
+ y = self.out_proj(self.dropout(self.norm(AV) * U_flat)) # [B, T, D]
203
+ return y
204
+
205
+
206
+ class HSTULayer(nn.Module):
207
+ """
208
+ HSTUPointwiseAttention + Residual Connection
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ d_model: int,
214
+ num_heads: int,
215
+ dropout: float = 0.1,
216
+ use_rab_pos: bool = True,
217
+ rab_num_buckets: int = 32,
218
+ rab_max_distance: int = 128,
219
+ ):
220
+ super().__init__()
221
+ self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
222
+ self.dropout = nn.Dropout(dropout)
223
+ self.use_rab_pos = use_rab_pos
224
+ self.rel_pos_bias = (RelativePositionBias(num_heads=num_heads, num_buckets=rab_num_buckets, max_distance=rab_max_distance) if use_rab_pos else None)
225
+
226
+ def forward(
227
+ self,
228
+ x: torch.Tensor,
229
+ attn_mask: Optional[torch.Tensor] = None,
230
+ key_padding_mask: Optional[torch.Tensor] = None,
231
+ ) -> torch.Tensor:
232
+ """
233
+ x: [B, T, D]
234
+ """
235
+ B, T, D = x.shape
236
+ device = x.device
237
+ rab = None
238
+ if self.use_rab_pos:
239
+ rab = self.rel_pos_bias(seq_len=T, device=device) # [1, H, T, T]
240
+ out = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab)
241
+ return x + self.dropout(out)
242
+
243
+
244
+ class HSTU(BaseModel):
245
+ """
246
+ HSTU encoder for next-item prediction in a causal, generative setup.
247
+ Pipeline:
248
+ 1) Embed tokens + positions from the behavior history
249
+ 2) Apply stacked HSTU layers with causal mask and optional RAB
250
+ 3) Use the last valid position to produce next-item logits via tied LM head
251
+ """
252
+
253
+ @property
254
+ def model_name(self) -> str:
255
+ return "HSTU"
256
+
257
+ @property
258
+ def task_type(self) -> str:
259
+ return "multiclass"
260
+
261
+ def __init__(
262
+ self,
263
+ sequence_features: list[SequenceFeature],
264
+ dense_features: Optional[list[DenseFeature]] = None,
265
+ sparse_features: Optional[list[SparseFeature]] = None,
266
+ d_model: Optional[int] = None,
267
+ num_heads: int = 8,
268
+ num_layers: int = 4,
269
+ max_seq_len: int = 200,
270
+ dropout: float = 0.1,
271
+ # RAB settings
272
+ use_rab_pos: bool = True,
273
+ rab_num_buckets: int = 32,
274
+ rab_max_distance: int = 128,
275
+
276
+ tie_embeddings: bool = True,
277
+ target: Optional[list[str] | str] = None,
278
+ optimizer: str = "adam",
279
+ optimizer_params: Optional[dict] = None,
280
+ scheduler: Optional[str] = None,
281
+ scheduler_params: Optional[dict] = None,
282
+ loss_params: Optional[dict] = None,
283
+ embedding_l1_reg: float = 0.0,
284
+ dense_l1_reg: float = 0.0,
285
+ embedding_l2_reg: float = 0.0,
286
+ dense_l2_reg: float = 0.0,
287
+ device: str = "cpu",
288
+ **kwargs,
289
+ ):
290
+ if not sequence_features:
291
+ raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")
292
+
293
+ # demo version: use the first SequenceFeature as the main sequence
294
+ self.history_feature = sequence_features[0]
295
+
296
+ hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
297
+ # Make hidden_dim divisible by num_heads
298
+ if hidden_dim % num_heads != 0:
299
+ hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
300
+
301
+ self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
302
+ self.vocab_size = self.history_feature.vocab_size
303
+ self.max_seq_len = max_seq_len
304
+
305
+ super().__init__(
306
+ dense_features=dense_features,
307
+ sparse_features=sparse_features,
308
+ sequence_features=sequence_features,
309
+ target=target,
310
+ task=self.task_type,
311
+ device=device,
312
+ embedding_l1_reg=embedding_l1_reg,
313
+ dense_l1_reg=dense_l1_reg,
314
+ embedding_l2_reg=embedding_l2_reg,
315
+ dense_l2_reg=dense_l2_reg,
316
+ **kwargs,
317
+ )
318
+
319
+ # token & position embedding (RAB already encodes relative positions; a learned absolute position embedding is also added here)
320
+ self.token_embedding = nn.Embedding(
321
+ num_embeddings=self.vocab_size,
322
+ embedding_dim=hidden_dim,
323
+ padding_idx=self.padding_idx,
324
+ )
325
+ self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
326
+ self.input_dropout = nn.Dropout(dropout)
327
+
328
+ # HSTU layers
329
+ self.layers = nn.ModuleList([HSTULayer(d_model=hidden_dim, num_heads=num_heads, dropout=dropout, use_rab_pos=use_rab_pos,
330
+ rab_num_buckets=rab_num_buckets, rab_max_distance=rab_max_distance) for _ in range(num_layers)])
331
+
332
+ self.final_norm = nn.LayerNorm(hidden_dim)
333
+ self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
334
+ if tie_embeddings:
335
+ self.lm_head.weight = self.token_embedding.weight
336
+
337
+ # causal mask buffer
338
+ self.register_buffer("causal_mask", torch.empty(0), persistent=False)
339
+ self.ignore_index = self.padding_idx if self.padding_idx is not None else -100
340
+
341
+ optimizer_params = optimizer_params or {}
342
+ scheduler_params = scheduler_params or {}
343
+ loss_params = loss_params or {}
344
+ loss_params.setdefault("ignore_index", self.ignore_index)
345
+
346
+ self.compile(optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params, loss="crossentropy", loss_params=loss_params)
347
+ self._register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])
348
+
349
+ def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
350
+ """
351
+ build causal mask of shape [T, T]: upper triangle is -inf, others are 0.
352
+ Positions marked -inf are treated as disallowed by the attention layer (applied multiplicatively, not added to the logits).
353
+ """
354
+ if self.causal_mask.numel() == 0 or self.causal_mask.size(0) < seq_len:
355
+ mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
356
+ mask = torch.triu(mask, diagonal=1)
357
+ self.causal_mask = mask
358
+ return self.causal_mask[:seq_len, :seq_len]
359
+
360
+ def _trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
361
+ if seq.size(1) <= self.max_seq_len:
362
+ return seq
363
+ return seq[:, -self.max_seq_len :]
364
+
365
+ def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
366
+ seq = x[self.history_feature.name].long() # [B, T_raw]
367
+ seq = self._trim_sequence(seq) # [B, T]
368
+
369
+ B, T = seq.shape
370
+ device = seq.device
371
+ # position ids: [B, T]
372
+ pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
373
+ token_emb = self.token_embedding(seq) # [B, T, D]
374
+ pos_emb = self.position_embedding(pos_ids) # [B, T, D]
375
+ hidden_states = self.input_dropout(token_emb + pos_emb)
376
+
377
+ # padding mask: True = pad
378
+ padding_mask = seq.eq(self.padding_idx) # [B, T]
379
+ attn_mask = self._build_causal_mask(seq_len=T, device=device) # [T, T]
380
+
381
+ for layer in self.layers:
382
+ hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
383
+ hidden_states = self.final_norm(hidden_states) # [B, T, D]
384
+
385
+ valid_lengths = (~padding_mask).sum(dim=1) # [B]
386
+ last_index = (valid_lengths - 1).clamp(min=0)
387
+ last_hidden = hidden_states[torch.arange(B, device=device), last_index] # [B, D]
388
+
389
+ logits = self.lm_head(last_hidden) # [B, vocab_size]
390
+ return logits
391
+
392
+ def compute_loss(self, y_pred, y_true):
393
+ """
394
+ y_true: [B] or [B, 1], the id of the next item.
395
+ """
396
+ if y_true is None:
397
+ raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
398
+ labels = y_true.view(-1).long()
399
+ return self.loss_fn[0](y_pred, labels)
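
To make the softmax-free aggregation of Eq. (2) concrete, below is a minimal, self-contained PyTorch sketch on toy tensors. It is not nextrec code: the head splitting, RAB term, U-gating, and output projection of HSTUPointwiseAttention are omitted, and the causal constraint is applied multiplicatively (matching the allowed-mask approach above) so that no -inf ever reaches the SiLU.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, d = 2, 2, 5, 4                                # batch, heads, sequence length, head dim

q = torch.randn(B, H, T, d)
k = torch.randn(B, H, T, d)
v = torch.randn(B, H, T, d)

# causal "allowed" mask: 1 where key j <= query i, 0 elsewhere
allowed = torch.tril(torch.ones(T, T)).view(1, 1, T, T)

logits = q @ k.transpose(-2, -1) * d ** -0.5           # [B, H, T, T], no softmax
attn = F.silu(logits) * allowed                        # pointwise activation, masked multiplicatively
attn = attn / allowed.sum(dim=-1, keepdim=True)        # divide by N = number of visible keys
out = attn @ v                                         # [B, H, T, d]

print(out.shape, torch.isfinite(out).all())            # torch.Size([2, 2, 5, 4]) tensor(True)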
nextrec/models/match/mind.py
@@ -13,7 +13,116 @@ from typing import Literal
13
13
 
14
14
  from nextrec.basic.model import BaseMatchModel
15
15
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
16
- from nextrec.basic.layers import MLP, EmbeddingLayer, CapsuleNetwork
16
+ from nextrec.basic.layers import MLP, EmbeddingLayer
17
+
18
+ class MultiInterestSA(nn.Module):
19
+ """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""
20
+
21
+ def __init__(self, embedding_dim, interest_num, hidden_dim=None):
22
+ super(MultiInterestSA, self).__init__()
23
+ self.embedding_dim = embedding_dim
24
+ self.interest_num = interest_num
25
+ # fall back to 4 * embedding_dim when hidden_dim is not given; otherwise keep the explicit value
26
+ self.hidden_dim = hidden_dim if hidden_dim is not None else self.embedding_dim * 4
27
+ self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
28
+ self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
29
+ self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
30
+
31
+ def forward(self, seq_emb, mask=None):
32
+ H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
33
+ if mask is not None:
34
+ A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
35
+ A = F.softmax(A, dim=1)
36
+ else:
37
+ A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
38
+ A = A.permute(0, 2, 1)
39
+ multi_interest_emb = torch.matmul(A, seq_emb)
40
+ return multi_interest_emb
41
+
42
+
43
+ class CapsuleNetwork(nn.Module):
44
+ """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
45
+
46
+ def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
47
+ super(CapsuleNetwork, self).__init__()
48
+ self.embedding_dim = embedding_dim # h
49
+ self.seq_len = seq_len # s
50
+ self.bilinear_type = bilinear_type
51
+ self.interest_num = interest_num
52
+ self.routing_times = routing_times
53
+
54
+ self.relu_layer = relu_layer
55
+ self.stop_grad = True
56
+ self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
57
+ if self.bilinear_type == 0: # MIND
58
+ self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
59
+ elif self.bilinear_type == 1:
60
+ self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
61
+ else:
62
+ self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
63
+ nn.init.xavier_uniform_(self.w)
64
+
65
+ def forward(self, item_eb, mask):
66
+ if self.bilinear_type == 0:
67
+ item_eb_hat = self.linear(item_eb)
68
+ item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
69
+ elif self.bilinear_type == 1:
70
+ item_eb_hat = self.linear(item_eb)
71
+ else:
72
+ u = torch.unsqueeze(item_eb, dim=2)
73
+ item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
74
+
75
+ item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
76
+ item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
77
+ item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
78
+
79
+ if self.stop_grad:
80
+ item_eb_hat_iter = item_eb_hat.detach()
81
+ else:
82
+ item_eb_hat_iter = item_eb_hat
83
+
84
+ if self.bilinear_type > 0:
85
+ capsule_weight = torch.zeros(item_eb_hat.shape[0],
86
+ self.interest_num,
87
+ self.seq_len,
88
+ device=item_eb.device,
89
+ requires_grad=False)
90
+ else:
91
+ capsule_weight = torch.randn(item_eb_hat.shape[0],
92
+ self.interest_num,
93
+ self.seq_len,
94
+ device=item_eb.device,
95
+ requires_grad=False)
96
+
97
+ for i in range(self.routing_times): # dynamic routing, run routing_times iterations (3 by default)
98
+ atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
99
+ paddings = torch.zeros_like(atten_mask, dtype=torch.float)
100
+
101
+ capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
102
+ capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
103
+ capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
104
+
105
+ if i < 2:
106
+ interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
107
+ cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
108
+ scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
109
+ interest_capsule = scalar_factor * interest_capsule
110
+
111
+ delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
112
+ delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
113
+ capsule_weight = capsule_weight + delta_weight
114
+ else:
115
+ interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
116
+ cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
117
+ scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
118
+ interest_capsule = scalar_factor * interest_capsule
119
+
120
+ interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
121
+
122
+ if self.relu_layer:
123
+ interest_capsule = self.relu(interest_capsule)
124
+
125
+ return interest_capsule
17
126
 
18
127
 
19
128
  class MIND(BaseMatchModel):
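
The scalar_factor lines inside the routing loop above implement the capsule "squash" nonlinearity. Here is a standalone sketch of the same formula (the squash helper is illustrative, not a nextrec function), showing how it suppresses weak interest vectors and caps strong ones near unit norm:

import torch

def squash(capsule: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # ||v||^2 / (1 + ||v||^2) * v / ||v||, matching the scalar_factor used in CapsuleNetwork
    sq_norm = capsule.pow(2).sum(dim=-1, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm) / torch.sqrt(sq_norm + eps)
    return scale * capsule

weak = squash(torch.full((1, 4), 0.01))
strong = squash(torch.full((1, 4), 100.0))
print(weak.norm().item())     # ~0.0004: weak interests are pushed toward zero
print(strong.norm().item())   # ~1.0:    strong interests saturate at unit length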
nextrec/models/multi_task/esmm.py
@@ -1,7 +1,44 @@
1
1
  """
2
2
  Date: create on 09/11/2025
3
+ Checkpoint: edit on 29/11/2025
3
4
  Author: Yang Zhou,zyaztec@gmail.com
4
- Reference: [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
5
+ Reference:
6
+ [1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach
7
+ for estimating post-click conversion rate[C]//SIGIR. 2018: 1137-1140.
8
+ (https://dl.acm.org/doi/10.1145/3209978.3210007)
9
+
10
+ Entire Space Multi-task Model (ESMM) targets CVR estimation by jointly optimizing
11
+ CTR and CTCVR on the full impression space, mitigating sample selection bias and
12
+ conversion sparsity. CTR predicts P(click | impression), CVR predicts P(conversion |
13
+ click), and their product forms CTCVR supervised on impression labels.
14
+
15
+ Workflow:
16
+ (1) Shared embeddings encode all features from impressions
17
+ (2) CTR tower outputs click probability conditioned on impression
18
+ (3) CVR tower outputs conversion probability conditioned on click
19
+ (4) CTCVR = CTR * CVR enables end-to-end training without filtering clicked data
20
+
21
+ Key Advantages:
22
+ - Trains on the entire impression space to remove selection bias
23
+ - Transfers rich click signals to sparse conversion prediction via shared embeddings
24
+ - Stable optimization by decomposing CTCVR into well-defined sub-tasks
25
+ - Simple architecture that can pair with other multi-task variants
26
+
27
+ ESMM (Entire Space Multi-task Model) targets CVR estimation by jointly training
28
+ CTR and CTCVR on the entire impression space, easing sample selection bias and conversion sparsity.
29
+ CTR predicts P(click|impression), CVR predicts P(conversion|click), and their product CTCVR is supervised directly on impression labels.
30
+
31
+ Workflow:
32
+ (1) A shared embedding layer encodes all impression features
33
+ (2) The CTR tower outputs the click probability given an impression
34
+ (3) The CVR tower outputs the conversion probability given a click
35
+ (4) Multiplying CTR and CVR gives CTCVR, so training is not restricted to the clicked subset
36
+
37
+ Key advantages:
38
+ - Training on the impression space avoids sample selection bias
39
+ - Shared representations transfer click signals to the sparse conversion task
40
+ - Decomposing CTCVR into sub-tasks keeps optimization stable
41
+ - A simple structure that combines easily with other multi-task methods
5
42
  """
6
43
 
7
44
  import torch
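
As a quick numeric check of the factorization described above (toy values, independent of the ESMM class that follows): the CVR tower is only ever supervised through the CTR * CVR product, so no click-filtered subset is needed.

import torch

ctr = torch.tensor([0.30, 0.05, 0.80])    # P(click | impression)
cvr = torch.tensor([0.20, 0.50, 0.10])    # P(conversion | click)
ctcvr = ctr * cvr                         # P(click & conversion | impression)
print(ctcvr)                              # tensor([0.0600, 0.0250, 0.0800])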
@@ -77,37 +114,22 @@ class ESMM(BaseModel):
77
114
 
78
115
  # All features
79
116
  self.all_features = dense_features + sparse_features + sequence_features
80
-
81
117
  # Shared embedding layer
82
118
  self.embedding = EmbeddingLayer(features=self.all_features)
119
+ input_dim = self.embedding.input_dim # input dimension derived from the embedding layer (replaces the manual sum kept below for reference)
120
+ # emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
121
+ # dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
122
+ # input_dim = emb_dim_total + dense_input_dim
83
123
 
84
- # Calculate input dimension
85
- emb_dim_total = sum([f.embedding_dim for f in self.all_features if not isinstance(f, DenseFeature)])
86
- dense_input_dim = sum([getattr(f, "embedding_dim", 1) or 1 for f in dense_features])
87
- input_dim = emb_dim_total + dense_input_dim
88
-
89
124
  # CTR tower
90
125
  self.ctr_tower = MLP(input_dim=input_dim, output_layer=True, **ctr_params)
91
126
 
92
127
  # CVR tower
93
128
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
94
- self.prediction_layer = PredictionLayer(
95
- task_type=self.task_type,
96
- task_dims=[1, 1]
97
- )
98
-
129
+ self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
99
130
  # Register regularization weights
100
- self._register_regularization_weights(
101
- embedding_attr='embedding',
102
- include_modules=['ctr_tower', 'cvr_tower']
103
- )
104
-
105
- self.compile(
106
- optimizer=optimizer,
107
- optimizer_params=optimizer_params,
108
- loss=loss,
109
- loss_params=loss_params,
110
- )
131
+ self._register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
132
+ self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
111
133
 
112
134
  def forward(self, x):
113
135
  # Get all embeddings and flatten
@@ -119,11 +141,8 @@ class ESMM(BaseModel):
119
141
  logits = torch.cat([ctr_logit, cvr_logit], dim=1)
120
142
  preds = self.prediction_layer(logits)
121
143
  ctr, cvr = preds.chunk(2, dim=1)
122
-
123
- # CTCVR prediction: P(click & conversion | impression) = P(click) * P(conversion | click)
124
144
  ctcvr = ctr * cvr # [B, 1]
125
145
 
126
- # Output: [CTR, CTCVR]
127
- # Note: We supervise CTR with click labels and CTCVR with conversion labels
146
+ # Output: [CTR, CTCVR]; CTR is supervised with click labels and CTCVR with conversion labels
128
147
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
129
148
  return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
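
Given the [B, 2] output above (column 0 = CTR, column 1 = CTCVR), the supervision described in the comment could look like the following sketch; the column layout of y_true here is an assumption for illustration, not the package's actual training API.

import torch
import torch.nn.functional as F

y_pred = torch.tensor([[0.30, 0.06],          # [CTR, CTCVR] for two impressions
                       [0.80, 0.08]])
y_true = torch.tensor([[1.0, 1.0],            # [click, conversion] labels on the impression space
                       [1.0, 0.0]])

ctr_loss = F.binary_cross_entropy(y_pred[:, 0], y_true[:, 0])
ctcvr_loss = F.binary_cross_entropy(y_pred[:, 1], y_true[:, 1])
loss = ctr_loss + ctcvr_loss                  # CVR is learned implicitly through the CTCVR factorization
print(loss.item())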