nextrec-0.3.1-py3-none-any.whl → nextrec-0.3.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.3.1"
+ __version__ = "0.3.2"
nextrec/basic/loggers.py CHANGED
@@ -106,7 +106,7 @@ def setup_logger(session_id: str | os.PathLike | None = None):

      console_format = '%(message)s'
      file_format = '%(asctime)s - %(levelname)s - %(message)s'
-     date_format = '%H:%M:%S'
+     date_format = '%Y-%m-%d %H:%M:%S'

      logger = logging.getLogger()
      logger.setLevel(logging.INFO)
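For reference, a minimal sketch (not taken from the package) of how the widened date format renders once it is wired into a logging.Formatter in the style of the file format above; the logger name and message here are illustrative.

    import logging

    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(
        fmt='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',  # 0.3.1 used '%H:%M:%S', so log lines carried no date
    ))
    demo_logger = logging.getLogger("nextrec.demo")
    demo_logger.addHandler(handler)
    demo_logger.setLevel(logging.INFO)
    demo_logger.info("training started")  # e.g. 2025-12-01 09:30:00 - INFO - training started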
nextrec/basic/model.py CHANGED
@@ -216,10 +216,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
          return train_loader, valid_split

      def compile(
-         self, optimizer="adam", optimizer_params: dict | None = None,
-         scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_params: dict | None = None,
-         loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,
-         loss_weights: int | float | list[int | float] | None = None,):
+         self,
+         optimizer: str | torch.optim.Optimizer = "adam",
+         optimizer_params: dict | None = None,
+         scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+         scheduler_params: dict | None = None,
+         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+         loss_params: dict | list[dict] | None = None,
+         loss_weights: int | float | list[int | float] | None = None,
+     ):
          optimizer_params = optimizer_params or {}
          self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
          self._optimizer_params = optimizer_params
@@ -1081,6 +1086,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
          logger.info(f" Early Stop Patience: {self._early_stop_patience}")
          logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
          logger.info(f" Session ID: {self.session_id}")
+         logger.info(f" Features Config Path: {self.features_config_path}")
          logger.info(f" Latest Checkpoint: {self.checkpoint_path}")

          logger.info("")
@@ -1195,7 +1201,7 @@ class BaseMatchModel(BaseModel):
      def compile(self,
                  optimizer: str | torch.optim.Optimizer = "adam",
                  optimizer_params: dict | None = None,
-                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
                  scheduler_params: dict | None = None,
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
                  loss_params: dict | list[dict] | None = None):
nextrec/models/generative/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .hstu import HSTU
+
+ __all__ = [
+     "HSTU",
+ ]
nextrec/models/generative/hstu.py ADDED
@@ -0,0 +1,399 @@
+ """
+ [Info: this version is not released yet, I need to do more research on the source code and paper]
+ Date: created on 01/12/2025
+ Checkpoint: edited on 01/12/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ Reference:
+ [1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
+ [2] Zhai J, Liao L, Liu X, et al. Actions speak louder than words: Trillion-parameter sequential transducers for generative recommendations. arXiv:2402.17152.
+
+ Hierarchical Sequential Transduction Unit (HSTU) is the core encoder behind
+ Meta's Generative Recommenders. It replaces softmax attention with lightweight
+ pointwise activations, enabling extremely deep stacks on long behavior sequences.
+
+ In each HSTU layer:
+ (1) Tokens are projected into four streams U, V, Q, K via a shared feed-forward block
+ (2) Softmax-free interactions combine QK^T with Relative Attention Bias (RAB) to encode distance
+ (3) Aggregated context is modulated by U-gating and mapped back through an output projection
+
+ Stacking layers yields an efficient causal encoder for next-item
+ generation. With a tied-embedding LM head, HSTU forms
+ a full generative recommendation model.
+
+ Key Advantages:
+ - Softmax-free attention scales better on deep/long sequences
+ - RAB captures temporal structure without extra attention heads
+ - Causal masking and padding-aware normalization fit real logs
+ - Weight tying reduces parameters and stabilizes training
+ - Serves as a drop-in backbone for generative recommendation
+
+ HSTU (Hierarchical Sequential Transduction Unit) is the core encoder of Meta's generative
+ recommenders; it replaces softmax attention with pointwise activations, so deep stacks scale easily on long sequences.
+
+ Main steps of a single HSTU layer:
+ (1) Project the input into the four streams U, V, Q, K in a single pass
+ (2) Model distance information with softmax-free QK^T combined with a relative position bias (RAB)
+ (3) Gate the aggregated context with U, then project back to the output space
+
+ Stacking multiple layers yields an efficient causal encoder; combined with a weight-tied LM head it performs next-item prediction.
+
+ Main advantages:
+ - Dropping softmax makes it easier to scale to long sequences and deep models
+ - The relative position bias robustly captures temporal structure
+ - Causal masking and padding-aware normalization match real-world logs
+ - The weight-tied output head reduces parameters and improves training stability
+ - Serves directly as the backbone for generative recommendation
+ """
+
+ from __future__ import annotations
+
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from nextrec.basic.model import BaseModel
+ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+
+
+ def _relative_position_bucket(
+     relative_position: torch.Tensor,
+     num_buckets: int = 32,
+     max_distance: int = 128,
+ ) -> torch.Tensor:
+     """
+     map the relative position (i-j) to a bucket in [0, num_buckets).
+     """
+     # only need the negative part for causal attention
+     n = -relative_position
+     n = torch.clamp(n, min=0)
+
+     # when the distance is small, keep it exact
+     max_exact = num_buckets // 2
+     is_small = n < max_exact
+
+     # when the distance is too far, do log scaling
+     large_val = max_exact + ((torch.log(n.float() / max_exact + 1e-6) / math.log(max_distance / max_exact)) * (num_buckets - max_exact)).long()
+     large_val = torch.clamp(large_val, max=num_buckets - 1)
+
+     buckets = torch.where(is_small, n.long(), large_val)
+     return buckets
+
+
+ class RelativePositionBias(nn.Module):
+     """
+     Compute relative position bias (RAB) for HSTU attention.
+     The input is the sequence length T, output is [1, num_heads, seq_len, seq_len].
+     """
+
+     def __init__(
+         self,
+         num_heads: int,
+         num_buckets: int = 32,
+         max_distance: int = 128,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.num_buckets = num_buckets
+         self.max_distance = max_distance
+         self.embedding = nn.Embedding(num_buckets, num_heads)
+
+     def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         # positions: [T]
+         ctx = torch.arange(seq_len, device=device)[:, None]
+         mem = torch.arange(seq_len, device=device)[None, :]
+         rel_pos = mem - ctx  # a matrix to describe all relative positions for each [i,j] pair, shape = [seq_len, seq_len]
+         buckets = _relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)  # map to buckets
+         values = self.embedding(buckets)  # embedding vector for each [i,j] pair, shape = [seq_len, seq_len, embedding_dim=num_heads]
+         return values.permute(2, 0, 1).unsqueeze(0)  # [1, num_heads, seq_len, seq_len]
+
+ class HSTUPointwiseAttention(nn.Module):
+     """
+     Pointwise aggregation attention that implements HSTU without softmax:
+     1) [U, V, Q, K] = split( φ1(f1(X)) ), U: gate, V: value, Q: query, K: key
+     2) AV = φ2(QK^T + rab) V / N, AV is the attention-weighted value
+     3) Y = f2( Norm(AV) ⊙ U ), Y is the output
+     φ1, φ2 use SiLU; Norm uses LayerNorm.
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dropout: float = 0.1,
+         alpha: float | None = None,
+     ):
+         super().__init__()
+         if d_model % num_heads != 0:
+             raise ValueError(f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0")
+
+         self.d_model = d_model
+         self.num_heads = num_heads
+         self.d_head = d_model // num_heads
+         self.alpha = alpha if alpha is not None else (self.d_head ** -0.5)
+         # project input to 4 * d_model for U, V, Q, K
+         self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
+         # project output back to d_model
+         self.out_proj = nn.Linear(d_model, d_model, bias=True)
+         self.dropout = nn.Dropout(dropout)
+         self.norm = nn.LayerNorm(d_model)
+
+     def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         [B, T, D] -> [B, H, T, d_head]
+         """
+         B, T, D = x.shape
+         return x.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 or -inf
+         key_padding_mask: Optional[torch.Tensor] = None,  # [B, T], True = pad
+         rab: Optional[torch.Tensor] = None,  # [1, H, T, T] or None
+     ) -> torch.Tensor:
+         B, T, D = x.shape
+
+         # Eq.(1): [U, V, Q, K] = split( φ1(f1(X)) )
+         h = F.silu(self.in_proj(x))  # [B, T, 4D]
+         U, V, Q, K = h.chunk(4, dim=-1)  # each [B, T, D]
+
+         Qh = self._reshape_heads(Q)  # [B, H, T, d_head]
+         Kh = self._reshape_heads(K)  # [B, H, T, d_head]
+         Vh = self._reshape_heads(V)  # [B, H, T, d_head]
+         Uh = self._reshape_heads(U)  # [B, H, T, d_head]
+
+         # attention logits: QK^T scaled by alpha (no softmax)
+         logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha  # [B, H, T, T]
+
+         # add relative position bias (rab^p), and future extensible rab^t
+         if rab is not None:
+             # rab: [1, H, T, T] or [B, H, T, T]
+             logits = logits + rab
+
+         # construct an "allowed" mask to calculate N
+         # 1 indicates that the (query i, key j) pair is a valid attention pair; 0 indicates it is masked out
+         allowed = torch.ones_like(logits, dtype=torch.float)  # [B, H, T, T]
+
+         # causal mask: attn_mask is usually an upper triangular matrix of -inf with shape [T, T]
+         if attn_mask is not None:
+             allowed = allowed * (attn_mask.view(1, 1, T, T) == 0).float()
+
+         # padding mask: key_padding_mask is usually [B, T], True = pad
+         if key_padding_mask is not None:
+             # valid: 1 for non-pad, 0 for pad
+             valid = (~key_padding_mask).float()  # [B, T]
+             valid = valid.view(B, 1, 1, T)  # [B, 1, 1, T]
+             allowed = allowed * valid
+
+         # Eq.(2): A(X)V(X) = φ2(QK^T + rab) V(X) / N
+         # masking is applied multiplicatively after φ2: there is no softmax renormalization,
+         # and SiLU(-inf) evaluates to NaN, so disallowed (future or padded) pairs are zeroed out
+         attn = F.silu(logits) * allowed  # [B, H, T, T]
+         denom = allowed.sum(dim=-1, keepdim=True)  # [B, H, T, 1]
+         denom = denom.clamp(min=1.0)
+
+         attn = attn / denom  # [B, H, T, T]
+         AV = torch.matmul(attn, Vh)  # [B, H, T, d_head]
+         AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # reshape back to [B, T, D]
+         U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
+         y = self.out_proj(self.dropout(self.norm(AV) * U_flat))  # [B, T, D]
+         return y
+
+
+ class HSTULayer(nn.Module):
+     """
+     HSTUPointwiseAttention + Residual Connection
+     """
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         dropout: float = 0.1,
+         use_rab_pos: bool = True,
+         rab_num_buckets: int = 32,
+         rab_max_distance: int = 128,
+     ):
+         super().__init__()
+         self.attn = HSTUPointwiseAttention(d_model=d_model, num_heads=num_heads, dropout=dropout)
+         self.dropout = nn.Dropout(dropout)
+         self.use_rab_pos = use_rab_pos
+         self.rel_pos_bias = (RelativePositionBias(num_heads=num_heads, num_buckets=rab_num_buckets, max_distance=rab_max_distance) if use_rab_pos else None)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attn_mask: Optional[torch.Tensor] = None,
+         key_padding_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """
+         x: [B, T, D]
+         """
+         B, T, D = x.shape
+         device = x.device
+         rab = None
+         if self.use_rab_pos:
+             rab = self.rel_pos_bias(seq_len=T, device=device)  # [1, H, T, T]
+         out = self.attn(x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab)
+         return x + self.dropout(out)
+
+
+ class HSTU(BaseModel):
+     """
+     HSTU encoder for next-item prediction in a causal, generative setup.
+     Pipeline:
+     1) Embed tokens + positions from the behavior history
+     2) Apply stacked HSTU layers with causal mask and optional RAB
+     3) Use the last valid position to produce next-item logits via tied LM head
+     """
+
+     @property
+     def model_name(self) -> str:
+         return "HSTU"
+
+     @property
+     def task_type(self) -> str:
+         return "multiclass"
+
+     def __init__(
+         self,
+         sequence_features: list[SequenceFeature],
+         dense_features: Optional[list[DenseFeature]] = None,
+         sparse_features: Optional[list[SparseFeature]] = None,
+         d_model: Optional[int] = None,
+         num_heads: int = 8,
+         num_layers: int = 4,
+         max_seq_len: int = 200,
+         dropout: float = 0.1,
+         # RAB settings
+         use_rab_pos: bool = True,
+         rab_num_buckets: int = 32,
+         rab_max_distance: int = 128,
+
+         tie_embeddings: bool = True,
+         target: Optional[list[str] | str] = None,
+         optimizer: str = "adam",
+         optimizer_params: Optional[dict] = None,
+         scheduler: Optional[str] = None,
+         scheduler_params: Optional[dict] = None,
+         loss_params: Optional[dict] = None,
+         embedding_l1_reg: float = 0.0,
+         dense_l1_reg: float = 0.0,
+         embedding_l2_reg: float = 0.0,
+         dense_l2_reg: float = 0.0,
+         device: str = "cpu",
+         **kwargs,
+     ):
+         if not sequence_features:
+             raise ValueError("[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history).")
+
+         # demo version: use the first SequenceFeature as the main sequence
+         self.history_feature = sequence_features[0]
+
+         hidden_dim = d_model or max(int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32)
+         # Make hidden_dim divisible by num_heads
+         if hidden_dim % num_heads != 0:
+             hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
+
+         self.padding_idx = self.history_feature.padding_idx if self.history_feature.padding_idx is not None else 0
+         self.vocab_size = self.history_feature.vocab_size
+         self.max_seq_len = max_seq_len
+
+         super().__init__(
+             dense_features=dense_features,
+             sparse_features=sparse_features,
+             sequence_features=sequence_features,
+             target=target,
+             task=self.task_type,
+             device=device,
+             embedding_l1_reg=embedding_l1_reg,
+             dense_l1_reg=dense_l1_reg,
+             embedding_l2_reg=embedding_l2_reg,
+             dense_l2_reg=dense_l2_reg,
+             **kwargs,
+         )
+
+         # token & position embedding (paper usually includes pos embedding / RAB in encoder)
+         self.token_embedding = nn.Embedding(
+             num_embeddings=self.vocab_size,
+             embedding_dim=hidden_dim,
+             padding_idx=self.padding_idx,
+         )
+         self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
+         self.input_dropout = nn.Dropout(dropout)
+
+         # HSTU layers
+         self.layers = nn.ModuleList([HSTULayer(d_model=hidden_dim, num_heads=num_heads, dropout=dropout, use_rab_pos=use_rab_pos,
+                                                rab_num_buckets=rab_num_buckets, rab_max_distance=rab_max_distance) for _ in range(num_layers)])
+
+         self.final_norm = nn.LayerNorm(hidden_dim)
+         self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
+         if tie_embeddings:
+             self.lm_head.weight = self.token_embedding.weight
+
+         # causal mask buffer
+         self.register_buffer("causal_mask", torch.empty(0), persistent=False)
+         self.ignore_index = self.padding_idx if self.padding_idx is not None else -100
+
+         optimizer_params = optimizer_params or {}
+         scheduler_params = scheduler_params or {}
+         loss_params = loss_params or {}
+         loss_params.setdefault("ignore_index", self.ignore_index)
+
+         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, scheduler=scheduler, scheduler_params=scheduler_params, loss="crossentropy", loss_params=loss_params)
+         self._register_regularization_weights(embedding_attr="token_embedding", include_modules=["layers", "lm_head"])
+
+     def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         """
+         build causal mask of shape [T, T]: upper triangle is -inf, others are 0.
+         Positions marked -inf are excluded from attention to enforce the causal structure.
+         """
+         if self.causal_mask.numel() == 0 or self.causal_mask.size(0) < seq_len:
+             mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
+             mask = torch.triu(mask, diagonal=1)
+             self.causal_mask = mask
+         return self.causal_mask[:seq_len, :seq_len]
+
+     def _trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
+         if seq.size(1) <= self.max_seq_len:
+             return seq
+         return seq[:, -self.max_seq_len :]
+
+     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
+         seq = x[self.history_feature.name].long()  # [B, T_raw]
+         seq = self._trim_sequence(seq)  # [B, T]
+
+         B, T = seq.shape
+         device = seq.device
+         # position ids: [B, T]
+         pos_ids = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
+         token_emb = self.token_embedding(seq)  # [B, T, D]
+         pos_emb = self.position_embedding(pos_ids)  # [B, T, D]
+         hidden_states = self.input_dropout(token_emb + pos_emb)
+
+         # padding mask: True = pad
+         padding_mask = seq.eq(self.padding_idx)  # [B, T]
+         attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]
+
+         for layer in self.layers:
+             hidden_states = layer(x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask)
+         hidden_states = self.final_norm(hidden_states)  # [B, T, D]
+
+         valid_lengths = (~padding_mask).sum(dim=1)  # [B]
+         last_index = (valid_lengths - 1).clamp(min=0)
+         last_hidden = hidden_states[torch.arange(B, device=device), last_index]  # [B, D]
+
+         logits = self.lm_head(last_hidden)  # [B, vocab_size]
+         return logits
+
+     def compute_loss(self, y_pred, y_true):
+         """
+         y_true: [B] or [B, 1], the id of the next item.
+         """
+         if y_true is None:
+             raise ValueError("[HSTU-compute_loss] Training requires y_true (next item id).")
+         labels = y_true.view(-1).long()
+         return self.loss_fn[0](y_pred, labels)
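To make the tensor shapes concrete, here is a minimal sketch (not part of the package) that runs one HSTULayer from the file above on random embeddings; the import path follows this wheel's RECORD entry, and in the full model HSTU.forward additionally supplies the causal and padding masks.

    import torch
    from nextrec.models.generative.hstu import HSTULayer

    batch_size, seq_len, d_model, num_heads = 2, 16, 64, 8
    layer = HSTULayer(d_model=d_model, num_heads=num_heads, dropout=0.1, use_rab_pos=True)

    x = torch.randn(batch_size, seq_len, d_model)             # [B, T, D] token + position embeddings
    out = layer(x=x, attn_mask=None, key_padding_mask=None)   # RAB is built internally from the sequence length
    print(out.shape)                                          # torch.Size([2, 16, 64]); the residual keeps [B, T, D]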
nextrec/utils/optimizer.py CHANGED
@@ -10,7 +10,7 @@ from typing import Iterable


  def get_optimizer(
-     optimizer: str = "adam",
+     optimizer: str | torch.optim.Optimizer = "adam",
      params: Iterable[torch.nn.Parameter] | None = None,
      **optimizer_params
  ):
@@ -51,7 +51,11 @@ def get_optimizer(
      return optimizer_fn


- def get_scheduler(scheduler, optimizer, **scheduler_params):
+ def get_scheduler(
+     scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None,
+     optimizer,
+     **scheduler_params
+ ):
      """
      Get learning rate scheduler function.
@@ -66,7 +70,7 @@ def get_scheduler(scheduler, optimizer, **scheduler_params):
              scheduler_fn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_params)
          else:
              raise NotImplementedError(f"Unsupported scheduler: {scheduler}")
-     elif isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler):
+     elif isinstance(scheduler, (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.LRScheduler)):
          scheduler_fn = scheduler
      else:
          raise TypeError(f"Invalid scheduler type: {type(scheduler)}")
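A small usage sketch (not from the package) of the widened get_scheduler: the import path follows this wheel's RECORD entry, and only the instance branch visible in the hunk above is exercised, since the string-keyed branches mostly fall outside this diff.

    import torch
    from nextrec.utils.optimizer import get_scheduler

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-3)

    # On PyTorch >= 2.0 the built-in schedulers derive from LRScheduler rather than
    # _LRScheduler, so an already-constructed instance now passes the widened check.
    step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    scheduler = get_scheduler(step_lr, optimizer)  # the elif branch above keeps the instance as-is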
nextrec-0.3.1.dist-info/METADATA → nextrec-0.3.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.3.1
+ Version: 0.3.2
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.3.1-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.3.2-orange.svg)

  English | [中文文档](README_zh.md)

@@ -75,13 +75,19 @@ English | [中文文档](README_zh.md)

  NextRec is a modern recommendation framework built on PyTorch, delivering a unified experience for modeling, training, and evaluation. It follows a modular design with rich model implementations, data-processing utilities, and engineering-ready training components. NextRec focuses on large-scale industrial recall scenarios on Spark clusters, training on massive offline parquet features.

- ### Why NextRec
+ ## Why NextRec

  - **Unified feature engineering & data pipeline**: Dense/Sparse/Sequence feature definitions, persistent DataProcessor, and batch-optimized RecDataLoader, matching offline feature training/inference in industrial big-data settings.
  - **Multi-scenario coverage**: Ranking (CTR/CVR), retrieval, multi-task learning, and more marketing/rec models, with a continuously expanding model zoo.
  - **Developer-friendly experience**: Stream processing/training/inference for csv/parquet/pathlike data, plus GPU/MPS acceleration and visualization support.
  - **Efficient training & evaluation**: Standardized engine with optimizers, LR schedulers, early stopping, checkpoints, and detailed logging out of the box.

+ ## Architecture
+
+ NextRec adopts a modular and low-coupling engineering design, enabling full-pipeline reusability and scalability across data processing → model construction → training & evaluation → inference & deployment. Its core components include: a Feature-Spec-driven Embedding architecture, the BaseModel abstraction, a set of independent reusable Layers, a unified DataLoader for both training and inference, and a ready-to-use Model Zoo.
+
+ ![NextRec Architecture](asserts/nextrec_diagram_en.png)
+
  > The project borrows ideas from excellent open-source rec libraries. Early layers referenced [torch-rechub](https://github.com/datawhalechina/torch-rechub) but have been replaced with in-house implementations. torch-rechub remains mature in architecture and models; the author contributed a bit there—feel free to check it out.

  ---
@@ -104,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)

- > Current version [0.3.1]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
+ > Current version [0.3.2]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.

  ## 5-Minute Quick Start

nextrec-0.3.1.dist-info/RECORD → nextrec-0.3.2.dist-info/RECORD CHANGED
@@ -1,13 +1,13 @@
  nextrec/__init__.py,sha256=CvocnY2uBp0cjNkhrT6ogw0q2bN9s1GNp754FLO-7lo,1117
- nextrec/__version__.py,sha256=r4xAFihOf72W9TD-lpMi6ntWSTKTP2SlzKP1ytkjRbI,22
+ nextrec/__version__.py,sha256=vNiWJ14r_cw5t_7UDqDQIVZvladKFGyHH2avsLpN7Vg,22
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nextrec/basic/activation.py,sha256=1qs9pq4hT3BUxIiYdYs57axMCm4-JyOBFQ6x7xkHTwM,2849
  nextrec/basic/callback.py,sha256=wwh0I2kKYyywCB-sG9eQXShlpXFJIo75qApJmnI5p6c,1036
  nextrec/basic/features.py,sha256=JtB63jqOIL7zZ5zoTgvEM4fEoqexMz0SMTmowTURk1I,4626
  nextrec/basic/layers.py,sha256=zIa8QsPkOOovjrMAUC94SfhSVTS4R_CXySBr5KAk6i4,24686
- nextrec/basic/loggers.py,sha256=TtTN5NIH8yqY27R2jXxQxfsTIA8XUBPJakx6Bl2ofhI,3724
+ nextrec/basic/loggers.py,sha256=VNed0LagpoPSUl2itW8hHT-BSqJHTlQY5pVxIVmm6AE,3733
  nextrec/basic/metrics.py,sha256=YFOaUexHJncc6sPbw2LF2sBnFp-3PLMrjR3aQbBDpGs,20891
- nextrec/basic/model.py,sha256=X1eH9XAxIQla-hVGKUxqEm7QyZucp_tIbx6FWYTa24M,73140
+ nextrec/basic/model.py,sha256=Doq5KOYrUHavpSa8RkHbT98ZhbFGRpRsA_9K1A5gU9c,73453
  nextrec/basic/session.py,sha256=oaATn-nzbJ9A6SGbMut9xLV_NSh9_1KmVDeNauS06Ps,4767
  nextrec/data/__init__.py,sha256=COaTyiARV7hEQTT3e74uyCBGmHFQ9rhe6g6Shc-Ualw,1064
  nextrec/data/data_utils.py,sha256=H-isIrs2FPyLSTe7IiFUkn6SQKfO0BkGKmj43C9yLGY,7602
@@ -18,7 +18,8 @@ nextrec/loss/listwise.py,sha256=gxDbO1td5IeS28jKzdE35o1KAYBRdCYoMzyZzfNLhc0,5689
  nextrec/loss/loss_utils.py,sha256=uZ4m9ChLr-UgIc5Yxm1LjwXDDepApQ-Fas8njweZ9qg,2641
  nextrec/loss/pairwise.py,sha256=MN_3Pk6Nj8KCkmUqGT5cmyx1_nQa3TIx_kxXT_HB58c,3396
  nextrec/loss/pointwise.py,sha256=shgdRJwTV7vAnVxHSffOJU4TPQeKyrwudQ8y-R10nYM,7144
- nextrec/models/generative/hstu.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ nextrec/models/generative/__init__.py,sha256=vo8-DloD74cKc1moSH-4GYG99w8Yi8YPGPxh8XDJPoc,50
+ nextrec/models/generative/hstu.py,sha256=qTS05XQBjgC5K34A07DSgIITMs1-ADZ8KVb-HEyNh9w,16369
  nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  nextrec/models/match/__init__.py,sha256=ASZB5abqKPhDbk8NErNNNa0DHuWpsVxvUtyEn5XMx6Y,215
  nextrec/models/match/dssm.py,sha256=e0hUqNLJVwTRVz4F4EiO8KLOOprKRBDtI4ID6Y1Tc60,8232
@@ -49,8 +50,8 @@ nextrec/utils/__init__.py,sha256=A3mH6M-DmDBWQ1stIIaTsNzvUy_AKaUWtRmrzU5R3FE,429
  nextrec/utils/common.py,sha256=YTlJkFCvIH5ExiOvg5pNPdRLUQ-h60BX4xTliaXKDsE,1217
  nextrec/utils/embedding.py,sha256=yxYSdFx0cJITh3Gf-K4SdhwRtKGcI0jOsyBgZ0NLa_c,465
  nextrec/utils/initializer.py,sha256=ffYOs5QuIns_d_-5e40iNtg6s1ftgREJN-ueq_NbDQE,1647
- nextrec/utils/optimizer.py,sha256=85ifoy2IQgjPHOqLqr1ho7XBGE_0ry1yEB9efS6C2lM,2446
- nextrec-0.3.1.dist-info/METADATA,sha256=bYvcXVXbnD8hAW8Y-cVKj--ngfd1kCo36atLZFhszT8,15808
- nextrec-0.3.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
- nextrec-0.3.1.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
- nextrec-0.3.1.dist-info/RECORD,,
+ nextrec/utils/optimizer.py,sha256=EUjAGFPeyou_Cv-_2HRvjzut8y_qpAQudc8L2T0k8zw,2706
+ nextrec-0.3.2.dist-info/METADATA,sha256=4RFzGjoOmLQUS1wIyJ6edrJgJNZkH9wcxOrQxSLln4w,16319
+ nextrec-0.3.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ nextrec-0.3.2.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
+ nextrec-0.3.2.dist-info/RECORD,,