nextrec 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/features.py +10 -23
  3. nextrec/basic/layers.py +18 -61
  4. nextrec/basic/loggers.py +1 -1
  5. nextrec/basic/metrics.py +55 -33
  6. nextrec/basic/model.py +258 -394
  7. nextrec/data/__init__.py +2 -2
  8. nextrec/data/data_utils.py +80 -4
  9. nextrec/data/dataloader.py +36 -57
  10. nextrec/data/preprocessor.py +5 -4
  11. nextrec/models/generative/__init__.py +5 -0
  12. nextrec/models/generative/hstu.py +399 -0
  13. nextrec/models/match/dssm.py +2 -2
  14. nextrec/models/match/dssm_v2.py +2 -2
  15. nextrec/models/match/mind.py +2 -2
  16. nextrec/models/match/sdm.py +2 -2
  17. nextrec/models/match/youtube_dnn.py +2 -2
  18. nextrec/models/multi_task/esmm.py +1 -1
  19. nextrec/models/multi_task/mmoe.py +1 -1
  20. nextrec/models/multi_task/ple.py +1 -1
  21. nextrec/models/multi_task/poso.py +1 -1
  22. nextrec/models/multi_task/share_bottom.py +1 -1
  23. nextrec/models/ranking/afm.py +1 -1
  24. nextrec/models/ranking/autoint.py +1 -1
  25. nextrec/models/ranking/dcn.py +1 -1
  26. nextrec/models/ranking/deepfm.py +1 -1
  27. nextrec/models/ranking/dien.py +1 -1
  28. nextrec/models/ranking/din.py +1 -1
  29. nextrec/models/ranking/fibinet.py +1 -1
  30. nextrec/models/ranking/fm.py +1 -1
  31. nextrec/models/ranking/masknet.py +2 -2
  32. nextrec/models/ranking/pnn.py +1 -1
  33. nextrec/models/ranking/widedeep.py +1 -1
  34. nextrec/models/ranking/xdeepfm.py +1 -1
  35. nextrec/utils/__init__.py +2 -1
  36. nextrec/utils/common.py +21 -2
  37. nextrec/utils/optimizer.py +7 -3
  38. {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/METADATA +10 -4
  39. nextrec-0.3.3.dist-info/RECORD +57 -0
  40. nextrec-0.3.1.dist-info/RECORD +0 -56
  41. {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/WHEEL +0 -0
  42. {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/licenses/LICENSE +0 -0
nextrec/models/generative/hstu.py (new file)
@@ -0,0 +1,399 @@
+ """
+ [Info: this version is not released yet; it needs more research on the source code and paper]
+ Date: create on 01/12/2025
+ Checkpoint: edit on 01/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ Reference:
+ [1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
+ [2] Zhai J, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
+
+ Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
+ Meta’s Generative Recommenders. It replaces softmax attention with lightweight
+ pointwise activations, enabling extremely deep stacks on long behavior sequences.
+
+ In each HSTU layer:
+ (1) Tokens are projected into four streams U, V, Q, K via a shared feed-forward block
+ (2) Softmax-free interactions combine QK^T with a Relative Attention Bias (RAB) to encode distance
+ (3) The aggregated context is modulated by U-gating and mapped back through an output projection
+
+ Stacking layers yields an efficient causal encoder for next-item generation.
+ With a tied-embedding LM head, HSTU forms a full generative recommendation model.
+
+ Key Advantages:
+ - Softmax-free attention scales better on deep/long sequences
+ - RAB captures temporal structure without extra attention heads
+ - Causal masking and padding-aware normalization fit real logs
+ - Weight tying reduces parameters and stabilizes training
+ - Serves as a drop-in backbone for generative recommendation
+
+ HSTU (Hierarchical Sequential Transduction Unit) is the core encoder of Meta's
+ generative recommenders; it replaces softmax attention with pointwise activations,
+ so deep stacks remain practical on long sequences.
+
+ Main steps in a single HSTU layer:
+ (1) Map the input in one pass to the four streams U, V, Q, K
+ (2) Model distance with softmax-free QK^T plus a relative position bias (RAB)
+ (3) Gate the aggregated context with U, then map it back to the output space
+
+ Stacking multiple layers gives an efficient causal encoder; combined with a
+ tied-weight LM head it performs next-item prediction.
+
+ Main advantages:
+ - Dropping softmax scales more easily to long sequences and deep models
+ - The relative position bias robustly captures temporal structure
+ - Causal masking and padding-aware normalization fit real logs
+ - The tied output head cuts parameters and improves training stability
+ - Serves directly as the backbone of a generative recommender
+ """
+
+ from __future__ import annotations
+
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from nextrec.basic.model import BaseModel
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+
+
+ def _relative_position_bucket(
+     relative_position: torch.Tensor,
+     num_buckets: int = 32,
+     max_distance: int = 128,
+ ) -> torch.Tensor:
+     """
+     Map the relative position (i - j) to a bucket in [0, num_buckets).
+     """
+     # only the negative side is needed for causal attention
+     n = -relative_position
+     n = torch.clamp(n, min=0)
+
+     # small distances keep their exact bucket
+     max_exact = num_buckets // 2
+     is_small = n < max_exact
+
+     # large distances are binned on a log scale
+     large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
+     large_val = torch.clamp(large_val, max=num_buckets - 1)
+
+     buckets = torch.where(is_small, n.long(), large_val)
+     return buckets
+
+
+ class RelativePositionBias(nn.Module):
+     """
+     Compute the relative position bias (RAB) for HSTU attention.
+     Given the sequence length T, returns a bias of shape [1, num_heads, T, T].
+     """
+
+     def __init__(
+         self,
+         num_heads: int,
+         num_buckets: int = 32,
+         max_distance: int = 128,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.num_buckets = num_buckets
+         self.max_distance = max_distance
+         self.embedding = nn.Embedding(num_buckets, num_heads)
+
+     def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         # positions: [T]
+         ctx = torch.arange(seq_len, device=device)[:, None]
+         mem = torch.arange(seq_len, device=device)[None, :]
+         rel_pos = mem - ctx  # relative position of every [i, j] pair, shape [seq_len, seq_len]
+         buckets = _relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)  # map to buckets
+         values = self.embedding(buckets)  # bias per [i, j] pair, shape [seq_len, seq_len, num_heads]
+         return values.permute(2, 0, 1).unsqueeze(0)  # [1, num_heads, seq_len, seq_len]
+
+
+ class HSTUPointwiseAttention(nn.Module):
+     """
+     Pointwise aggregation attention that implements HSTU without softmax:
+     1) [U, V, Q, K] = split( φ1(f1(X)) ), U: gate, V: value, Q: query, K: key
+     2) AV = φ2(QK^T + rab) V / N, the attention-weighted value
+     3) Y = f2( Norm(AV) ⊙ U ), the layer output
+     φ1, φ2 use SiLU; Norm uses LayerNorm.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dropout: float = 0.1,
+         alpha: float | None = None,
+     ):
+         super().__init__()
+         if d_model % num_heads != 0:
+             raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")
+
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.d_head = d_model // num_heads
+         self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
+         # project input to 4 * d_model for U, V, Q, K
+         self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
+         # project output back to d_model
+         self.out_proj = nn.Linear(d_model, d_model, bias=True)
+         self.dropout = nn.Dropout(dropout)
+         self.norm = nn.LayerNorm(d_model)
+
+     def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         [B, T, D] -> [B, H, T, d_head]
+         """
+         B, T, D = x.shape
+         return x.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 or -inf
+         key_padding_mask: Optional[torch.Tensor] = None,  # [B, T], True = pad
+         rab: Optional[torch.Tensor] = None,  # [1, H, T, T] or None
+     ) -> torch.Tensor:
+         B, T, D = x.shape
+
+         # Eq.(1): [U, V, Q, K] = split( φ1(f1(X)) )
+         h = F.silu(self.in_proj(x))  # [B, T, 4D]
+         U, V, Q, K = h.chunk(4, dim=-1)  # each [B, T, D]
+
+         Qh = self._reshape_heads(Q)  # [B, H, T, d_head]
+         Kh = self._reshape_heads(K)  # [B, H, T, d_head]
+         Vh = self._reshape_heads(V)  # [B, H, T, d_head]
+         Uh = self._reshape_heads(U)  # [B, H, T, d_head]
+
+         # attention logits: scaled QK^T without softmax (alpha defaults to 1/sqrt(d_head))
+         logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha  # [B, H, T, T]
+
+         # add relative position bias (rab^p); rab^t is a possible future extension
+         if rab is not None:
+             # rab: [1, H, T, T] or [B, H, T, T]
+             logits = logits + rab
+
+         # "allowed" mask, used both for masking and for the normalizer N:
+         # 1 marks a valid (query i, key j) pair, 0 marks a masked-out pair
+         allowed = torch.ones_like(logits)  # [B, H, T, T]
+
+         # causal mask: attn_mask is an upper-triangular matrix of -inf with shape [T, T]
+         if attn_mask is not None:
+             allowed = allowed * (attn_mask.view(1, 1, T, T) == 0).float()
+
+         # padding mask: key_padding_mask is [B, T], True = pad
+         if key_padding_mask is not None:
+             valid = (~key_padding_mask).float()  # [B, T], 1 for non-pad, 0 for pad
+             allowed = allowed * valid.view(B, 1, 1, T)  # [B, 1, 1, T]
+
+         # Eq.(2): A(X)V(X) = φ2(QK^T + rab) V(X) / N
+         # mask multiplicatively after the activation: adding -inf to the logits
+         # before SiLU would produce NaN, since silu(-inf) = -inf * sigmoid(-inf)
+         attn = F.silu(logits) * allowed  # [B, H, T, T]
+         denom = allowed.sum(dim=-1, keepdim=True)  # [B, H, T, 1]
+         denom = denom.clamp(min=1.0)
+
+         attn = attn / denom  # [B, H, T, T]
+         AV = torch.matmul(attn, Vh)  # [B, H, T, d_head]
+         AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # reshape back to [B, T, D]
+
+         # Eq.(3): Y = f2( Norm(AV) ⊙ U )
+         U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
+         y = self.out_proj(self.dropout(self.norm(AV) * U_flat))  # [B, T, D]
+         return y
+
+
+ class HSTULayer(nn.Module):
+     """
+     HSTUPointwiseAttention + residual connection.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dropout: float = 0.1,
+         use_rab_pos: bool = True,
+         rab_num_buckets: int = 32,
+         rab_max_distance: int = 128,
+     ):
+         super().__init__()
+         self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
+         self.dropout = nn.Dropout(dropout)
+         self.use_rab_pos = use_rab_pos
+         self.rel_pos_bias = (RelativePositionBias(num_heads=num_heads, num_buckets=rab_num_buckets, max_distance=rab_max_distance) if use_rab_pos else None)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         x: [B, T, D]
+         """
+         B, T, D = x.shape
+         device = x.device
+         rab = None
+         if self.use_rab_pos:
+             rab = self.rel_pos_bias(seq_len=T, device=device)  # [1, H, T, T]
+         out = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab)
+         return x + self.dropout(out)
+
+
+ class HSTU(BaseModel):
+     """
+     HSTU encoder for next-item prediction in a causal, generative setup.
+     Pipeline:
+     1) Embed tokens + positions from the behavior history
+     2) Apply stacked HSTU layers with a causal mask and optional RAB
+     3) Use the last valid position to produce next-item logits via a tied LM head
+     """
+
+     @property
+     def model_name(self) -> str:
+         return "HSTU"
+
+     @property
+     def task_type(self) -> str:
+         return "multiclass"
+
+     def __init__(
+         self,
+         sequence_features: list[SequenceFeature],
+         dense_features: Optional[list[DenseFeature]] = None,
+         sparse_features: Optional[list[SparseFeature]] = None,
+         d_model: Optional[int] = None,
+         num_heads: int = 8,
+         num_layers: int = 4,
+         max_seq_len: int = 200,
+         dropout: float = 0.1,
+         # RAB settings
+         use_rab_pos: bool = True,
+         rab_num_buckets: int = 32,
+         rab_max_distance: int = 128,
+         tie_embeddings: bool = True,
+         target: Optional[list[str] | str] = None,
+         optimizer: str = "adam",
+         optimizer_params: Optional[dict] = None,
+         scheduler: Optional[str] = None,
+         scheduler_params: Optional[dict] = None,
+         loss_params: Optional[dict] = None,
+         embedding_l1_reg: float = 0.0,
+         dense_l1_reg: float = 0.0,
+         embedding_l2_reg: float = 0.0,
+         dense_l2_reg: float = 0.0,
+         device: str = "cpu",
+         **kwargs,
+     ):
+         if not sequence_features:
+             raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")
+
+         # demo version: use the first SequenceFeature as the main sequence
+         self.history_feature = sequence_features[0]
+
+         hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
+         # round hidden_dim up so it is divisible by num_heads
+         if hidden_dim % num_heads != 0:
+             hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
+
+         self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
+         self.vocab_size = self.history_feature.vocab_size
+         self.max_seq_len = max_seq_len
+
+         super().__init__(
+             dense_features=dense_features,
+             sparse_features=sparse_features,
+             sequence_features=sequence_features,
+             target=target,
+             task=self.task_type,
+             device=device,
+             embedding_l1_reg=embedding_l1_reg,
+             dense_l1_reg=dense_l1_reg,
+             embedding_l2_reg=embedding_l2_reg,
+             dense_l2_reg=dense_l2_reg,
+             **kwargs,
+         )
+
+         # token & position embedding (the paper folds positional structure into the encoder via RAB)
+         self.token_embedding = nn.Embedding(
+             num_embeddings=self.vocab_size,
+             embedding_dim=hidden_dim,
+             padding_idx=self.padding_idx,
+         )
+         self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
+         self.input_dropout = nn.Dropout(dropout)
+
+         # HSTU layers
+         self.layers = nn.ModuleList([HSTULayer(d_model=hidden_dim, num_heads=num_heads, dropout=dropout, use_rab_pos=use_rab_pos,
+                                                rab_num_buckets=rab_num_buckets, rab_max_distance=rab_max_distance) for _ in range(num_layers)])
+
+         self.final_norm = nn.LayerNorm(hidden_dim)
+         self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
+         if tie_embeddings:
+             self.lm_head.weight = self.token_embedding.weight
+
+         # causal mask buffer
+         self.register_buffer("causal_mask", torch.empty(0), persistent=False)
+         self.ignore_index = self.padding_idx  # padding_idx is never None at this point
+
+         optimizer_params = optimizer_params or {}
+         scheduler_params = scheduler_params or {}
+         loss_params = loss_params or {}
+         loss_params.setdefault("ignore_index", self.ignore_index)
+
+         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params, loss="crossentropy", loss_params=loss_params)
+         self.register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])
+
+     def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         """
+         Build a causal mask of shape [T, T]: the upper triangle is -inf, everything else 0.
+         The attention module converts it into a multiplicative "allowed" mask.
+         """
+         if self.causal_mask.numel() == 0 or self.causal_mask.size(0) < seq_len:
+             mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
+             mask = torch.triu(mask, diagonal=1)
+             self.causal_mask = mask
+         return self.causal_mask[:seq_len, :seq_len]
+
+     def _trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
+         if seq.size(1) <= self.max_seq_len:
+             return seq
+         return seq[:, -self.max_seq_len:]
+
+     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+         seq = x[self.history_feature.name].long()  # [B, T_raw]
+         seq = self._trim_sequence(seq)  # [B, T]
+
+         B, T = seq.shape
+         device = seq.device
+         # position ids: [B, T]
+         pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
+         token_emb = self.token_embedding(seq)  # [B, T, D]
+         pos_emb = self.position_embedding(pos_ids)  # [B, T, D]
+         hidden_states = self.input_dropout(token_emb + pos_emb)
+
+         # padding mask: True = pad
+         padding_mask = seq.eq(self.padding_idx)  # [B, T]
+         attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]
+
+         for layer in self.layers:
+             hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
+         hidden_states = self.final_norm(hidden_states)  # [B, T, D]
+
+         # last valid position (assumes right-padded sequences)
+         valid_lengths = (~padding_mask).sum(dim=1)  # [B]
+         last_index = (valid_lengths - 1).clamp(min=0)
+         last_hidden = hidden_states[torch.arange(B, device=device), last_index]  # [B, D]
+
+         logits = self.lm_head(last_hidden)  # [B, vocab_size]
+         return logits
+
+     def compute_loss(self, y_pred, y_true):
+         """
+         y_true: [B] or [B, 1], the id of the next item.
+         """
+         if y_true is None:
+             raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
+         labels = y_true.view(-1).long()
+         return self.loss_fn[0](y_pred, labels)
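
For intuition, Eq.(2) above can be reproduced in a few lines. The following is a minimal standalone sketch (toy shapes, and a zero tensor standing in for the learned RAB; it does not import the nextrec classes from this diff): SiLU replaces softmax, masked pairs are zeroed after the activation, and the row normalizer is the count N of allowed pairs rather than a partition function.

import torch
import torch.nn.functional as F

B, H, T, d = 2, 2, 5, 4                                     # toy sizes, chosen for illustration
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))
rab = torch.zeros(1, H, T, T)                               # stand-in for the learned relative attention bias

logits = q @ k.transpose(-2, -1) * d ** -0.5 + rab          # [B, H, T, T]
allowed = torch.tril(torch.ones(T, T)).view(1, 1, T, T)     # causal: key j <= query i

attn = F.silu(logits) * allowed                             # zero masked pairs after SiLU, not before
attn = attn / allowed.sum(-1, keepdim=True).clamp(min=1.0)  # divide by N = number of allowed pairs
out = attn @ v                                              # [B, H, T, d]
print(out.shape)                                            # torch.Size([2, 2, 5, 4])

Each query row mixes only current and past positions, and no softmax appears anywhere; this is the property the module docstring credits for HSTU's scaling behavior.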
@@ -143,11 +143,11 @@ class DSSM(BaseMatchModel):
  activation=dnn_activation
  )
 
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn']
  )
@@ -134,11 +134,11 @@ class DSSM_v2(BaseMatchModel):
  activation=dnn_activation
  )
 
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn']
  )
@@ -258,11 +258,11 @@ class MIND(BaseMatchModel):
  else:
  self.item_dnn = None
 
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['capsule_network']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn'] if self.item_dnn else []
  )
@@ -176,11 +176,11 @@ class SDM(BaseMatchModel):
  else:
  self.item_dnn = None
 
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['rnn', 'user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn'] if self.item_dnn else []
  )
@@ -140,11 +140,11 @@ class YoutubeDNN(BaseMatchModel):
  activation=dnn_activation
  )
 
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='user_embedding',
  include_modules=['user_dnn']
  )
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='item_embedding',
  include_modules=['item_dnn']
  )
@@ -128,7 +128,7 @@ class ESMM(BaseModel):
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
 
  def forward(self, x):
@@ -146,7 +146,7 @@ class MMOE(BaseModel):
  self.towers.append(tower)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
 
  def forward(self, x):
@@ -249,7 +249,7 @@ class PLE(BaseModel):
  self.towers.append(tower)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
 
  def forward(self, x):
@@ -389,7 +389,7 @@ class POSO(BaseModel):
  self.tower_heads = None
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks,)
  include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
- self._register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
+ self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
 
  def forward(self, x):
@@ -122,7 +122,7 @@ class ShareBottom(BaseModel):
  self.towers.append(tower)
  self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
 
  def forward(self, x):
@@ -81,7 +81,7 @@ class AFM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear', 'attention_linear', 'attention_p', 'output_projection']
  )
@@ -150,7 +150,7 @@ class AutoInt(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['projection_layers', 'attention_layers', 'fc']
  )
@@ -109,7 +109,7 @@ class DCN(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['cross_network', 'mlp', 'final_layer']
  )
@@ -107,7 +107,7 @@ class DeepFM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
 
  def forward(self, x):
@@ -237,7 +237,7 @@ class DIEN(BaseModel):
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['interest_extractor', 'interest_evolution', 'attention_layer', 'mlp', 'candidate_proj'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['interest_extractor', 'interest_evolution', 'attention_layer', 'mlp', 'candidate_proj'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
 
  def forward(self, x):
@@ -108,7 +108,7 @@ class DIN(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['attention', 'mlp', 'candidate_attention_proj']
  )
@@ -104,7 +104,7 @@ class FiBiNET(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear', 'senet', 'bilinear_standard', 'bilinear_senet', 'mlp']
  )
@@ -69,7 +69,7 @@ class FM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear']
  )
@@ -234,10 +234,10 @@ class MaskNet(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  if self.model_type == "serial":
- self._register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
+ self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "output_layer"],)
  # serial
  else:
- self._register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"])
+ self.register_regularization_weights(embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
 
  def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
@@ -91,7 +91,7 @@ class PNN(BaseModel):
  modules = ['mlp']
  if self.product_type == "outer":
  modules.append('kernel')
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=modules
  )
@@ -111,7 +111,7 @@ class WideDeep(BaseModel):
  self.mlp = MLP(input_dim=input_dim, **mlp_params)
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
  # Register regularization weights
- self._register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
+ self.register_regularization_weights(embedding_attr='embedding', include_modules=['linear', 'mlp'])
  self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
 
  def forward(self, x):
@@ -121,7 +121,7 @@ class xDeepFM(BaseModel):
  self.prediction_layer = PredictionLayer(task_type=self.task_type)
 
  # Register regularization weights
- self._register_regularization_weights(
+ self.register_regularization_weights(
  embedding_attr='embedding',
  include_modules=['linear', 'cin', 'mlp']
  )
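
Every hunk after hstu.py is the same mechanical change: the BaseModel helper `_register_regularization_weights` loses its leading underscore and becomes the public `register_regularization_weights`, with an unchanged argument list at all of these call sites. For third-party models subclassing BaseModel, migrating from 0.3.1 to 0.3.3 is a one-line rename. A sketch follows; `MyRanker` and its module names are hypothetical, not part of nextrec:

from nextrec.basic.model import BaseModel

class MyRanker(BaseModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # nextrec 0.3.1: private helper
        # self._register_regularization_weights(embedding_attr='embedding', include_modules=['mlp'])
        # nextrec 0.3.3: public method, same signature
        self.register_regularization_weights(embedding_attr='embedding', include_modules=['mlp'])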