nextrec 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/model.py +11 -5
- nextrec/models/generative/__init__.py +5 -0
- nextrec/models/generative/hstu.py +399 -0
- nextrec/utils/optimizer.py +7 -3
- {nextrec-0.3.1.dist-info → nextrec-0.3.2.dist-info}/METADATA +10 -4
- {nextrec-0.3.1.dist-info → nextrec-0.3.2.dist-info}/RECORD +10 -9
- {nextrec-0.3.1.dist-info → nextrec-0.3.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.1.dist-info → nextrec-0.3.2.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.3.
|
|
1
|
+
__version__ = "0.3.2"
|
nextrec/basic/loggers.py
CHANGED
|
@@ -106,7 +106,7 @@ def setup_logger(session_id: str | os.PathLike | None = None):
|
|
|
106
106
|
|
|
107
107
|
console_format = '%(message)s'
|
|
108
108
|
file_format = '%(asctime)s - %(levelname)s - %(message)s'
|
|
109
|
-
date_format = '%H:%M:%S'
|
|
109
|
+
date_format = '%Y-%m-%d %H:%M:%S'
|
|
110
110
|
|
|
111
111
|
logger = logging.getLogger()
|
|
112
112
|
logger.setLevel(logging.INFO)
|
nextrec/basic/model.py
CHANGED
|
@@ -216,10 +216,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
|
|
|
216
216
|
return train_loader, valid_split
|
|
217
217
|
|
|
218
218
|
def compile(
|
|
219
|
-
self,
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
219
|
+
self,
|
|
220
|
+
optimizer: str | torch.optim.Optimizer = "adam",
|
|
221
|
+
optimizer_params: dict | None = None,
|
|
222
|
+
scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
|
|
223
|
+
scheduler_params: dict | None = None,
|
|
224
|
+
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
225
|
+
loss_params: dict | list[dict] | None = None,
|
|
226
|
+
loss_weights: int | float | list[int | float] | None = None,
|
|
227
|
+
):
|
|
223
228
|
optimizer_params = optimizer_params or {}
|
|
224
229
|
self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
|
|
225
230
|
self._optimizer_params = optimizer_params
|
|
@@ -1081,6 +1086,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
|
|
|
1081
1086
|
logger.info(f" Early Stop Patience: {self._early_stop_patience}")
|
|
1082
1087
|
logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
|
|
1083
1088
|
logger.info(f" Session ID: {self.session_id}")
|
|
1089
|
+
logger.info(f" Features Config Path: {self.features_config_path}")
|
|
1084
1090
|
logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
|
|
1085
1091
|
|
|
1086
1092
|
logger.info("")
|
|
@@ -1195,7 +1201,7 @@ class BaseMatchModel(BaseModel):
|
|
|
1195
1201
|
def compile(self,
|
|
1196
1202
|
optimizer: str | torch.optim.Optimizer = "adam",
|
|
1197
1203
|
optimizer_params: dict | None = None,
|
|
1198
|
-
scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
|
|
1204
|
+
scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
|
|
1199
1205
|
scheduler_params: dict | None = None,
|
|
1200
1206
|
loss: str | nn.Module | list[str | nn.Module] | None = "bce",
|
|
1201
1207
|
loss_params: dict | list[dict] | None = None):
|
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
|
|
2
|
+
[Info: this version is not released yet; I need to do more research on the source code and the paper.]
|
|
3
|
+
Date: create on 01/12/2025
|
|
4
|
+
Checkpoint: edit on 01/12/2025
|
|
5
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
6
|
+
Reference:
|
|
7
|
+
[1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
|
|
8
|
+
[2] Ma W, Li P, Chen C, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
|
|
9
|
+
|
|
10
|
+
Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
|
|
11
|
+
Meta’s Generative Recommenders. It replaces softmax attention with lightweight
|
|
12
|
+
pointwise activations, enabling extremely deep stacks on long behavior sequences.
|
|
13
|
+
|
|
14
|
+
In each HSTU layer:
|
|
15
|
+
(1) Tokens are projected into four streams U, V, Q, K via a shared feed-forward block
|
|
16
|
+
(2) Softmax-free interactions combine QK^T with Relative Attention Bias (RAB) to encode distance
|
|
17
|
+
(3) Aggregated context is modulated by U-gating and mapped back through an output projection
|
|
18
|
+
|
|
19
|
+
Stacking layers yields an efficient causal encoder for next-item
|
|
20
|
+
generation. With a tied-embedding LM head, HSTU forms
|
|
21
|
+
a full generative recommendation model.
|
|
22
|
+
|
|
23
|
+
Key Advantages:
|
|
24
|
+
- Softmax-free attention scales better on deep/long sequences
|
|
25
|
+
- RAB captures temporal structure without extra attention heads
|
|
26
|
+
- Causal masking and padding-aware normalization fit real logs
|
|
27
|
+
- Weight tying reduces parameters and stabilizes training
|
|
28
|
+
- Serves as a drop-in backbone for generative recommendation
|
|
29
|
+
|
|
30
|
+
HSTU(层次化序列转导单元)是 Meta 生成式推荐的核心编码器,
|
|
31
|
+
用点式激活替代 softmax 注意力,可在长序列上轻松堆叠深层结构。
|
|
32
|
+
|
|
33
|
+
单层 HSTU 的主要步骤:
|
|
34
|
+
(1) 将输入一次性映射到 U、V、Q、K 四条通路
|
|
35
|
+
(2) 利用不含 softmax 的 QK^T 结合相对位置偏置(RAB)建模距离信息
|
|
36
|
+
(3) 用 U 对聚合上下文进行门控,再映射回输出空间
|
|
37
|
+
|
|
38
|
+
多层堆叠后,可得到高效的因果编码器;与绑权 LM 头配合即可完成 next-item 预测。
|
|
39
|
+
|
|
40
|
+
主要优势:
|
|
41
|
+
- 摆脱 softmax,在长序列、深层模型上更易扩展
|
|
42
|
+
- 相对位置偏置稳健刻画时序结构
|
|
43
|
+
- 因果 mask 与 padding 感知归一化贴合真实日志
|
|
44
|
+
- 绑权输出头降低参数量并提升训练稳定性
|
|
45
|
+
- 直接作为生成式推荐的骨干网络
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
from __future__ import annotations
|
|
49
|
+
|
|
50
|
+
import math
|
|
51
|
+
from typing import Optional
|
|
52
|
+
|
|
53
|
+
import torch
|
|
54
|
+
import torch.nn as nn
|
|
55
|
+
import torch.nn.functional as F
|
|
56
|
+
|
|
57
|
+
from nextrec.basic.model import BaseModel
|
|
58
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _relative_position_bucket(
|
|
62
|
+
relative_position: torch.Tensor,
|
|
63
|
+
num_buckets: int = 32,
|
|
64
|
+
max_distance: int = 128,
|
|
65
|
+
) -> torch.Tensor:
|
|
66
|
+
"""
|
|
67
|
+
map the relative position (i-j) to a bucket in [0, num_buckets).
|
|
68
|
+
"""
|
|
69
|
+
# only need the negative part for causal attention
|
|
70
|
+
n = -relative_position
|
|
71
|
+
n = torch.clamp(n, min=0)
|
|
72
|
+
|
|
73
|
+
# when the distance is small, keep it exact
|
|
74
|
+
max_exact = num_buckets // 2
|
|
75
|
+
is_small = n < max_exact
|
|
76
|
+
|
|
77
|
+
# when the distance is too far, do log scaling
|
|
78
|
+
large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
|
|
79
|
+
large_val = torch.clamp(large_val, max=num_buckets - 1)
|
|
80
|
+
|
|
81
|
+
buckets = torch.where(is_small, n.long(), large_val)
|
|
82
|
+
return buckets
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RelativePositionBias(nn.Module):
    """
    Learned relative attention bias (RAB) for HSTU attention.

    Every (query i, key j) offset ``j - i`` is mapped to a bucket by
    ``_relative_position_bucket`` and looked up in a ``[num_buckets, num_heads]``
    embedding table, giving one scalar bias per head.
    Calling the module with a sequence length T yields a tensor of shape
    ``[1, num_heads, T, T]``.
    """

    def __init__(
        self,
        num_heads: int,
        num_buckets: int = 32,
        max_distance: int = 128,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        # one learned scalar per (bucket, head) pair
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return the bias for a length-``seq_len`` sequence, shape [1, num_heads, seq_len, seq_len]."""
        positions = torch.arange(seq_len, device=device)
        # offsets[i, j] = j - i (key position minus query position)
        offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
        bucket_ids = _relative_position_bucket(
            offsets,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        per_head = self.embedding(bucket_ids)  # [seq_len, seq_len, num_heads]
        return per_head.permute(2, 0, 1)[None]
|
|
111
|
+
|
|
112
|
+
class HSTUPointwiseAttention(nn.Module):
|
|
113
|
+
"""
|
|
114
|
+
Pointwise aggregation attention that implements HSTU without softmax:
|
|
115
|
+
1) [U, V, Q, K] = split( φ1(f1(X)) ), U: gate, V: value, Q: query, K: key
|
|
116
|
+
2) AV = φ2(QK^T + rab) V / N, av is attention-weighted value
|
|
117
|
+
3) Y = f2( Norm(AV) ⊙ U ), y is output
|
|
118
|
+
φ1, φ2 use SiLU; Norm uses LayerNorm.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(
|
|
122
|
+
self,
|
|
123
|
+
d_model: int,
|
|
124
|
+
num_heads: int,
|
|
125
|
+
dropout: float = 0.1,
|
|
126
|
+
alpha: float | None = None
|
|
127
|
+
):
|
|
128
|
+
super().__init__()
|
|
129
|
+
if d_model % num_heads != 0:
|
|
130
|
+
raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")
|
|
131
|
+
|
|
132
|
+
self.d_model = d_model
|
|
133
|
+
self.num_heads = num_heads
|
|
134
|
+
self.d_head = d_model // num_heads
|
|
135
|
+
self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
|
|
136
|
+
# project input to 4 * d_model for U, V, Q, K
|
|
137
|
+
self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
|
|
138
|
+
# project output back to d_model
|
|
139
|
+
self.out_proj = nn.Linear(d_model, d_model, bias=True)
|
|
140
|
+
self.dropout = nn.Dropout(dropout)
|
|
141
|
+
self.norm = nn.LayerNorm(d_model)
|
|
142
|
+
|
|
143
|
+
def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
|
|
144
|
+
"""
|
|
145
|
+
[B, T, D] -> [B, H, T, d_head]
|
|
146
|
+
"""
|
|
147
|
+
B, T, D = x.shape
|
|
148
|
+
return x.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
|
|
149
|
+
|
|
150
|
+
def forward(
|
|
151
|
+
self,
|
|
152
|
+
x: torch.Tensor,
|
|
153
|
+
attn_mask: Optional[torch.Tensor] = None, # [T, T] with 0 or -inf
|
|
154
|
+
key_padding_mask: Optional[torch.Tensor] = None, # [B, T], True = pad
|
|
155
|
+
rab: Optional[torch.Tensor] = None, # [1, H, T, T] or None
|
|
156
|
+
) -> torch.Tensor:
|
|
157
|
+
B, T, D = x.shape
|
|
158
|
+
|
|
159
|
+
# Eq.(1): [U, V, Q, K] = split( φ1(f1(X)) )
|
|
160
|
+
h = F.silu(self.in_proj(x)) # [B, T, 4D]
|
|
161
|
+
U, V, Q, K = h.chunk(4, dim=-1) # each [B, T, D]
|
|
162
|
+
|
|
163
|
+
Qh = self._reshape_heads(Q) # [B, H, T, d_head]
|
|
164
|
+
Kh = self._reshape_heads(K) # [B, H, T, d_head]
|
|
165
|
+
Vh = self._reshape_heads(V) # [B, H, T, d_head]
|
|
166
|
+
Uh = self._reshape_heads(U) # [B, H, T, d_head]
|
|
167
|
+
|
|
168
|
+
# attention logits: QK^T (without 1/sqrt(d) and softmax)
|
|
169
|
+
logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha # [B, H, T, T]
|
|
170
|
+
|
|
171
|
+
# add relative position bias (rab^p), and future extensible rab^t
|
|
172
|
+
if rab is not None:
|
|
173
|
+
# rab: [1, H, T, T] or [B, H, T, T]
|
|
174
|
+
logits = logits + rab
|
|
175
|
+
|
|
176
|
+
# construct an "allowed" mask to calculate N
|
|
177
|
+
# 1 indicates that the (query i, key j) pair is a valid attention pair; 0 indicates it is masked out
|
|
178
|
+
allowed = torch.ones_like(logits, dtype=torch.float) # [B, H, T, T]
|
|
179
|
+
|
|
180
|
+
# causal mask: attn_mask is usually an upper triangular matrix of -inf with shape [T, T]
|
|
181
|
+
if attn_mask is not None:
|
|
182
|
+
allowed = allowed * (attn_mask.view(1, 1, T, T) == 0).float()
|
|
183
|
+
logits = logits + attn_mask.view(1, 1, T, T)
|
|
184
|
+
|
|
185
|
+
# padding mask: key_padding_mask is usually [B, T], True = pad
|
|
186
|
+
if key_padding_mask is not None:
|
|
187
|
+
# valid: 1 for non-pad, 0 for pad
|
|
188
|
+
valid = (~key_padding_mask).float() # [B, T]
|
|
189
|
+
valid = valid.view(B, 1, 1, T) # [B, 1, 1, T]
|
|
190
|
+
allowed = allowed * valid
|
|
191
|
+
logits = logits.masked_fill(valid == 0, float("-inf"))
|
|
192
|
+
|
|
193
|
+
# Eq.(2): A(X)V(X) = φ2(QK^T + rab) V(X) / N
|
|
194
|
+
attn = F.silu(logits) # [B, H, T, T]
|
|
195
|
+
denom = allowed.sum(dim=-1, keepdim=True) # [B, H, T, 1]
|
|
196
|
+
denom = denom.clamp(min=1.0)
|
|
197
|
+
|
|
198
|
+
attn = attn / denom # [B, H, T, T]
|
|
199
|
+
AV = torch.matmul(attn, Vh) # [B, H, T, d_head]
|
|
200
|
+
AV = AV.transpose(1, 2).contiguous().view(B, T, D) # reshape back to [B, T, D]
|
|
201
|
+
U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
|
|
202
|
+
y = self.out_proj(self.dropout(self.norm(AV) * U_flat)) # [B, T, D]
|
|
203
|
+
return y
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class HSTULayer(nn.Module):
    """
    One HSTU block: HSTUPointwiseAttention wrapped in dropout and a residual connection.

    When ``use_rab_pos`` is True the layer owns its own RelativePositionBias, and the
    bias is recomputed for the current sequence length on every forward pass.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        use_rab_pos: bool = True,
        rab_num_buckets: int = 32,
        rab_max_distance: int = 128,
    ):
        super().__init__()
        self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.use_rab_pos = use_rab_pos
        if use_rab_pos:
            self.rel_pos_bias = RelativePositionBias(
                num_heads=num_heads,
                num_buckets=rab_num_buckets,
                max_distance=rab_max_distance,
            )
        else:
            self.rel_pos_bias = None

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply attention and add it back onto the input.

        x: [B, T, D]; the returned tensor has the same shape.
        """
        _, seq_len, _ = x.shape
        rab = self.rel_pos_bias(seq_len=seq_len, device=x.device) if self.use_rab_pos else None  # [1, H, T, T]
        update = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab)
        return x + self.dropout(update)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class HSTU(BaseModel):
    """
    HSTU encoder for next-item prediction in a causal, generative setup.

    Pipeline:
        1) Embed item ids plus learned absolute positions of the (trimmed) behavior history
        2) Apply stacked HSTU layers with a causal mask, a key padding mask and optional RAB
        3) Take the hidden state at the last non-padding position and map it to next-item
           logits with the LM head (weight-tied to the token embedding by default)
    """

    @property
    def model_name(self) -> str:
        return "HSTU"

    @property
    def task_type(self) -> str:
        return "multiclass"

    def __init__(
        self,
        sequence_features: list[SequenceFeature],
        dense_features: Optional[list[DenseFeature]] = None,
        sparse_features: Optional[list[SparseFeature]] = None,
        d_model: Optional[int] = None,
        num_heads: int = 8,
        num_layers: int = 4,
        max_seq_len: int = 200,
        dropout: float = 0.1,
        # RAB settings
        use_rab_pos: bool = True,
        rab_num_buckets: int = 32,
        rab_max_distance: int = 128,

        tie_embeddings: bool = True,
        target: Optional[list[str] | str] = None,
        optimizer: str = "adam",
        optimizer_params: Optional[dict] = None,
        scheduler: Optional[str] = None,
        scheduler_params: Optional[dict] = None,
        loss_params: Optional[dict] = None,
        embedding_l1_reg: float = 0.0,
        dense_l1_reg: float = 0.0,
        embedding_l2_reg: float = 0.0,
        dense_l2_reg: float = 0.0,
        device: str = "cpu",
        **kwargs,
    ):
        """
        Args:
            sequence_features: behavior-history features; only the first one is encoded.
            d_model: hidden size; defaults to max(history embedding_dim, 32) and is rounded
                up to a multiple of ``num_heads``.
            max_seq_len: histories longer than this are trimmed to their last ``max_seq_len`` ids.
            use_rab_pos / rab_num_buckets / rab_max_distance: relative position bias settings.
            tie_embeddings: share the LM head weight with the token embedding.
            loss_params: extra arguments for the cross-entropy loss; ``ignore_index`` defaults
                to the padding id. The caller's dict is not modified.
            Remaining arguments are forwarded to BaseModel / compile.

        Raises:
            ValueError: if ``sequence_features`` is empty.
        """
        if not sequence_features:
            raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")

        # demo version: use the first SequenceFeature as the main sequence
        self.history_feature = sequence_features[0]

        hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
        # Make hidden_dim divisible by num_heads
        if hidden_dim % num_heads != 0:
            hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)

        self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
        self.vocab_size = self.history_feature.vocab_size
        self.max_seq_len = max_seq_len

        super().__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
            task=self.task_type,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            **kwargs,
        )

        # token & position embedding (paper usually includes pos embedding / RAB in encoder)
        self.token_embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=hidden_dim,
            padding_idx=self.padding_idx,
        )
        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
        self.input_dropout = nn.Dropout(dropout)

        # HSTU layers
        self.layers = nn.ModuleList([HSTULayer(d_model=hidden_dim, num_heads=num_heads, dropout=dropout, use_rab_pos=use_rab_pos,
                                               rab_num_buckets=rab_num_buckets, rab_max_distance=rab_max_distance) for _ in range(num_layers)])

        self.final_norm = nn.LayerNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
        if tie_embeddings:
            self.lm_head.weight = self.token_embedding.weight

        # causal mask cache (non-persistent buffer, grown lazily in _build_causal_mask)
        self.register_buffer("causal_mask", torch.empty(0), persistent=False)
        self.ignore_index = self.padding_idx if self.padding_idx is not None else -100

        optimizer_params = optimizer_params or {}
        scheduler_params = scheduler_params or {}
        # copy so setdefault below does not mutate the dict the caller passed in
        loss_params = dict(loss_params or {})
        loss_params.setdefault("ignore_index", self.ignore_index)

        self.compile(optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params, loss="crossentropy", loss_params=loss_params)
        self._register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])

    def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Return a [seq_len, seq_len] causal mask: -inf strictly above the diagonal, 0 elsewhere.

        The mask is cached in the ``causal_mask`` buffer and rebuilt when a longer sequence
        is seen or when the requested device differs from the cached one.
        """
        cached = self.causal_mask
        if cached.numel() == 0 or cached.size(0) < seq_len or cached.device != device:
            mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
            mask = torch.triu(mask, diagonal=1)
            self.causal_mask = mask
        return self.causal_mask[:seq_len, :seq_len]

    def _trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
        """Keep only the last ``max_seq_len`` positions of ``seq`` ([B, T_raw] -> [B, T])."""
        # NOTE(review): keeping the tail assumes the most recent ids are at the end; with
        # right-padded histories longer than max_seq_len this keeps padding — confirm
        # the padding convention of the data pipeline.
        if seq.size(1) <= self.max_seq_len:
            return seq
        return seq[:, -self.max_seq_len :]

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Encode the behavior history and return next-item logits of shape [B, vocab_size].

        ``x[self.history_feature.name]`` must be a 2-D tensor of item ids; ids equal to
        ``padding_idx`` are treated as padding.
        """
        seq = x[self.history_feature.name].long()  # [B, T_raw]
        seq = self._trim_sequence(seq)  # [B, T]

        B, T = seq.shape
        device = seq.device
        # position ids: [B, T]
        pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
        token_emb = self.token_embedding(seq)  # [B, T, D]
        pos_emb = self.position_embedding(pos_ids)  # [B, T, D]
        hidden_states = self.input_dropout(token_emb + pos_emb)

        # padding mask: True = pad
        padding_mask = seq.eq(self.padding_idx)  # [B, T]
        attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]

        for layer in self.layers:
            hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
        hidden_states = self.final_norm(hidden_states)  # [B, T, D]

        # index of the last non-padding position; works for both left- and right-padded
        # sequences (for right padding it equals valid_length - 1). All-padding rows use 0.
        last_index = pos_ids.masked_fill(padding_mask, -1).max(dim=1).values.clamp(min=0)  # [B]
        last_hidden = hidden_states[torch.arange(B, device=device), last_index]  # [B, D]

        logits = self.lm_head(last_hidden)  # [B, vocab_size]
        return logits

    def compute_loss(self, y_pred, y_true):
        """
        Cross-entropy between next-item logits and the target ids.

        y_true: [B] or [B, 1], the id of the next item.
        Raises ValueError when y_true is None.
        """
        if y_true is None:
            raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
        labels = y_true.view(-1).long()
        return self.loss_fn[0](y_pred, labels)
|
nextrec/utils/optimizer.py
CHANGED
|
@@ -10,7 +10,7 @@ from typing import Iterable
|
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
def get_optimizer(
|
|
13
|
-
optimizer: str = "adam",
|
|
13
|
+
optimizer: str | torch.optim.Optimizer = "adam",
|
|
14
14
|
params: Iterable[torch.nn.Parameter] | None = None,
|
|
15
15
|
**optimizer_params
|
|
16
16
|
):
|
|
@@ -51,7 +51,11 @@ def get_optimizer(
|
|
|
51
51
|
return optimizer_fn
|
|
52
52
|
|
|
53
53
|
|
|
54
|
-
def get_scheduler(
|
|
54
|
+
def get_scheduler(
|
|
55
|
+
scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None,
|
|
56
|
+
optimizer,
|
|
57
|
+
**scheduler_params
|
|
58
|
+
):
|
|
55
59
|
"""
|
|
56
60
|
Get learning rate scheduler function.
|
|
57
61
|
|
|
@@ -66,7 +70,7 @@ def get_scheduler(scheduler, optimizer, **scheduler_params):
|
|
|
66
70
|
scheduler_fn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_params)
|
|
67
71
|
else:
|
|
68
72
|
raise NotImplementedError(f"Unsupported scheduler: {scheduler}")
|
|
69
|
-
elif isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler):
|
|
73
|
+
elif isinstance(scheduler, (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler)):
|
|
70
74
|
scheduler_fn = scheduler
|
|
71
75
|
else:
|
|
72
76
|
raise TypeError(f"Invalid scheduler type: {type(scheduler)}")
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.3.
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
|
|
|
63
63
|

|
|
64
64
|

|
|
65
65
|

|
|
66
|
-

|
|
67
67
|
|
|
68
68
|
English | [中文文档](README_zh.md)
|
|
69
69
|
|
|
@@ -75,13 +75,19 @@ English | [中文文档](README_zh.md)
|
|
|
75
75
|
|
|
76
76
|
NextRec is a modern recommendation framework built on PyTorch, delivering a unified experience for modeling, training, and evaluation. It follows a modular design with rich model implementations, data-processing utilities, and engineering-ready training components. NextRec focuses on large-scale industrial recall scenarios on Spark clusters, training on massive offline parquet features.
|
|
77
77
|
|
|
78
|
-
|
|
78
|
+
## Why NextRec
|
|
79
79
|
|
|
80
80
|
- **Unified feature engineering & data pipeline**: Dense/Sparse/Sequence feature definitions, persistent DataProcessor, and batch-optimized RecDataLoader, matching offline feature training/inference in industrial big-data settings.
|
|
81
81
|
- **Multi-scenario coverage**: Ranking (CTR/CVR), retrieval, multi-task learning, and more marketing/rec models, with a continuously expanding model zoo.
|
|
82
82
|
- **Developer-friendly experience**: Stream processing/training/inference for csv/parquet/pathlike data, plus GPU/MPS acceleration and visualization support.
|
|
83
83
|
- **Efficient training & evaluation**: Standardized engine with optimizers, LR schedulers, early stopping, checkpoints, and detailed logging out of the box.
|
|
84
84
|
|
|
85
|
+
## Architecture
|
|
86
|
+
|
|
87
|
+
NextRec adopts a modular and low-coupling engineering design, enabling full-pipeline reusability and scalability across data processing → model construction → training & evaluation → inference & deployment. Its core components include: a Feature-Spec-driven Embedding architecture, the BaseModel abstraction, a set of independent reusable Layers, a unified DataLoader for both training and inference, and a ready-to-use Model Zoo.
|
|
88
|
+
|
|
89
|
+

|
|
90
|
+
|
|
85
91
|
> The project borrows ideas from excellent open-source rec libraries. Early layers referenced [torch-rechub](https://github.com/datawhalechina/torch-rechub) but have been replaced with in-house implementations. torch-rechub remains mature in architecture and models; the author contributed a bit there—feel free to check it out.
|
|
86
92
|
|
|
87
93
|
---
|
|
@@ -104,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
|
|
|
104
110
|
- [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
|
|
105
111
|
- [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
|
|
106
112
|
|
|
107
|
-
> Current version [0.3.
|
|
113
|
+
> Current version [0.3.2]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
|
|
108
114
|
|
|
109
115
|
## 5-Minute Quick Start
|
|
110
116
|
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
|
|
2
|
-
nextrec/__version__.py,sha256=
|
|
2
|
+
nextrec/__version__.py,sha256=vNiWJ14r_cw5t_7UDqDQIVZvladKFGyHH2avsLpN7Vg,22
|
|
3
3
|
nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
nextrec/basic/activation.py,sha256=1qs9pq4hT3BUxIiYdYs57axMCm4-JyOBFQ6x7xkHTwM,2849
|
|
5
5
|
nextrec/basic/callback.py,sha256=wwh0I2kKYyywCB-sG9eQXShlpXFJIo75qApJmnI5p6c,1036
|
|
6
6
|
nextrec/basic/features.py,sha256=JtB63jqOIL7zZ5zoTgvEM4fEoqexMz0SMTmowTURk1I,4626
|
|
7
7
|
nextrec/basic/layers.py,sha256=zIa8QsPkOOovjrMAUC94SfhSVTS4R_CXySBr5KAk6i4,24686
|
|
8
|
-
nextrec/basic/loggers.py,sha256=
|
|
8
|
+
nextrec/basic/loggers.py,sha256=VNed0LagpoPSUl2itW8hHT-BSqJHTlQY5pVxIVmm6AE,3733
|
|
9
9
|
nextrec/basic/metrics.py,sha256=YFOaUexHJncc6sPbw2LF2sBnFp-3PLMrjR3aQbBDpGs,20891
|
|
10
|
-
nextrec/basic/model.py,sha256=
|
|
10
|
+
nextrec/basic/model.py,sha256=Doq5KOYrUHavpSa8RkHbT98ZhbFGRpRsA_9K1A5gU9c,73453
|
|
11
11
|
nextrec/basic/session.py,sha256=oaATn-nzbJ9A6SGbMut9xLV_NSh9_1KmVDeNauS06Ps,4767
|
|
12
12
|
nextrec/data/__init__.py,sha256=COaTyiARV7hEQTT3e74uyCBGmHFQ9rhe6g6Shc-Ualw,1064
|
|
13
13
|
nextrec/data/data_utils.py,sha256=H-isIrs2FPyLSTe7IiFUkn6SQKfO0BkGKmj43C9yLGY,7602
|
|
@@ -18,7 +18,8 @@ nextrec/loss/listwise.py,sha256=gxDbO1td5IeS28jKzdE35o1KAYBRdCYoMzyZzfNLhc0,5689
|
|
|
18
18
|
nextrec/loss/loss_utils.py,sha256=uZ4m9ChLr-UgIc5Yxm1LjwXDDepApQ-Fas8njweZ9qg,2641
|
|
19
19
|
nextrec/loss/pairwise.py,sha256=MN_3Pk6Nj8KCkmUqGT5cmyx1_nQa3TIx_kxXT_HB58c,3396
|
|
20
20
|
nextrec/loss/pointwise.py,sha256=shgdRJwTV7vAnVxHSffOJU4TPQeKyrwudQ8y-R10nYM,7144
|
|
21
|
-
nextrec/models/generative/
|
|
21
|
+
nextrec/models/generative/__init__.py,sha256=vo8-DloD74cKc1moSH-4GYG99w8Yi8YPGPxh8XDJPoc,50
|
|
22
|
+
nextrec/models/generative/hstu.py,sha256=qTS05XQBjgC5K34A07DSgIITMs1-ADZ8KVb-HEyNh9w,16369
|
|
22
23
|
nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
23
24
|
nextrec/models/match/__init__.py,sha256=ASZB5abqKPhDbk8NErNNNa0DHuWpsVxvUtyEn5XMx6Y,215
|
|
24
25
|
nextrec/models/match/dssm.py,sha256=e0hUqNLJVwTRVz4F4EiO8KLOOprKRBDtI4ID6Y1Tc60,8232
|
|
@@ -49,8 +50,8 @@ nextrec/utils/__init__.py,sha256=A3mH6M-DmDBWQ1stIIaTsNzvUy_AKaUWtRmrzU5R3FE,429
|
|
|
49
50
|
nextrec/utils/common.py,sha256=YTlJkFCvIH5ExiOvg5pNPdRLUQ-h60BX4xTliaXKDsE,1217
|
|
50
51
|
nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
|
|
51
52
|
nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
|
|
52
|
-
nextrec/utils/optimizer.py,sha256=
|
|
53
|
-
nextrec-0.3.
|
|
54
|
-
nextrec-0.3.
|
|
55
|
-
nextrec-0.3.
|
|
56
|
-
nextrec-0.3.
|
|
53
|
+
nextrec/utils/optimizer.py,sha256=EUjAGFPeyou_Cv-_2HRvjzut8y_qpAQudc8L2T0k8zw,2706
|
|
54
|
+
nextrec-0.3.2.dist-info/METADATA,sha256=4RFzGjoOmLQUS1wIyJ6edrJgJNZkH9wcxOrQxSLln4w,16319
|
|
55
|
+
nextrec-0.3.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
56
|
+
nextrec-0.3.2.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
|
|
57
|
+
nextrec-0.3.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|