nextrec 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,9 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
     raise KeyError(f"Unsupported data type for extracting column {name}")


-def split_dict_random(
-    data_dict: dict, test_size: float = 0.2, random_state: int | None = None
-):
+def split_dict_random(data_dict, test_size=0.2, random_state=None):

     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
@@ -2,10 +2,12 @@
 Loss utilities for NextRec.

 Date: create on 27/10/2025
-Checkpoint: edit on 29/11/2025
+Checkpoint: edit on 17/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """

+from typing import Literal
+
 import torch.nn as nn

 from nextrec.loss.listwise import (
@@ -30,14 +32,81 @@ VALID_TASK_TYPES = [
     "regression",
 ]

+# Define all supported loss types
+LossType = Literal[
+    # Pointwise losses
+    "bce",
+    "binary_crossentropy",
+    "weighted_bce",
+    "focal",
+    "focal_loss",
+    "cb_focal",
+    "class_balanced_focal",
+    "crossentropy",
+    "ce",
+    "mse",
+    "mae",
+    # Pairwise ranking losses
+    "bpr",
+    "hinge",
+    "triplet",
+    # Listwise ranking losses
+    "sampled_softmax",
+    "softmax",
+    "infonce",
+    "listnet",
+    "listmle",
+    "approx_ndcg",
+]
+

-def _build_cb_focal(kw):
+def build_cb_focal(kw):
     if "class_counts" not in kw:
         raise ValueError("class_balanced_focal requires class_counts")
     return ClassBalancedFocalLoss(**kw)


-def get_loss_fn(loss=None, **kw):
+def get_loss_fn(loss: LossType | nn.Module | None = None, **kw) -> nn.Module:
+    """
+    Get loss function by name or return the provided loss module.
+
+    Args:
+        loss: Loss function name or nn.Module instance. Supported options:
+
+            **Pointwise Losses:**
+            - "bce", "binary_crossentropy": Binary Cross-Entropy Loss
+            - "weighted_bce": Weighted Binary Cross-Entropy Loss
+            - "focal", "focal_loss": Focal Loss (for class imbalance)
+            - "cb_focal", "class_balanced_focal": Class-Balanced Focal Loss (requires class_counts parameter)
+            - "crossentropy", "ce": Cross-Entropy Loss for multi-class classification
+            - "mse": Mean Squared Error Loss
+            - "mae": Mean Absolute Error Loss
+
+            **Pairwise Ranking Losses:**
+            - "bpr": Bayesian Personalized Ranking Loss
+            - "hinge": Hinge Loss
+            - "triplet": Triplet Loss
+
+            **Listwise Ranking Losses:**
+            - "sampled_softmax", "softmax": Sampled Softmax Loss
+            - "infonce": InfoNCE Loss
+            - "listnet": ListNet Loss
+            - "listmle": ListMLE Loss
+            - "approx_ndcg": Approximate NDCG Loss
+
+        **kw: Additional keyword arguments passed to the loss function
+
+    Returns:
+        nn.Module: Loss function instance
+
+    Raises:
+        ValueError: If loss is None or unsupported type
+
+    Examples:
+        >>> loss_fn = get_loss_fn("bce")
+        >>> loss_fn = get_loss_fn("focal", alpha=0.25, gamma=2.0)
+        >>> loss_fn = get_loss_fn("cb_focal", class_counts=[100, 50, 200])
+    """
     if isinstance(loss, nn.Module):
         return loss
     if loss is None:
@@ -49,7 +118,7 @@ def get_loss_fn(loss=None, **kw):
     if loss in ["focal", "focal_loss"]:
         return FocalLoss(**kw)
     if loss in ["cb_focal", "class_balanced_focal"]:
-        return _build_cb_focal(kw)
+        return build_cb_focal(kw)
     if loss in ["crossentropy", "ce"]:
         return nn.CrossEntropyLoss(**kw)
     if loss == "mse":
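The new `LossType` alias only constrains the string argument at type-check time; at runtime the function still dispatches on plain string comparisons and passes any `nn.Module` straight through. Below is a minimal, self-contained sketch of that dispatch pattern, not the library code itself; it maps only a few names onto standard torch losses for illustration.

```python
import torch
import torch.nn as nn

def get_loss_fn_sketch(loss=None, **kw) -> nn.Module:
    """Hypothetical, trimmed-down stand-in for nextrec's get_loss_fn."""
    if isinstance(loss, nn.Module):
        return loss  # custom loss modules pass straight through
    if loss in ("bce", "binary_crossentropy"):
        return nn.BCEWithLogitsLoss(**kw)  # assumed mapping, for the sketch only
    if loss in ("crossentropy", "ce"):
        return nn.CrossEntropyLoss(**kw)
    if loss == "mse":
        return nn.MSELoss(**kw)
    raise ValueError(f"Unsupported loss: {loss!r}")

loss_fn = get_loss_fn_sketch("mse")
pred, target = torch.randn(4), torch.randn(4)
print(loss_fn(pred, target))  # scalar tensor
```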
@@ -0,0 +1,16 @@
+"""
+Generative Recommendation Models
+
+This module contains generative models for recommendation tasks.
+"""
+
+from nextrec.models.generative.hstu import HSTU
+from nextrec.models.generative.rqvae import (
+    RQVAE,
+    RQ,
+    VQEmbedding,
+    BalancedKmeans,
+    kmeans,
+)
+
+__all__ = ["HSTU", "RQVAE", "RQ", "VQEmbedding", "BalancedKmeans", "kmeans"]
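With this new `__init__.py` in place, the generative models can be imported from the subpackage directly (assuming nextrec >= 0.4.7 is installed):

```python
# Import path taken from the __init__.py added above; requires nextrec >= 0.4.7.
from nextrec.models.generative import HSTU, RQVAE
```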
@@ -1,7 +1,7 @@
 """
 [Info: this version is not released yet, i need to more research on source code and paper]
 Date: create on 01/12/2025
-Checkpoint: edit on 01/12/2025
+Checkpoint: edit on 11/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Meta AI. Generative Recommenders (HSTU encoder) — https://github.com/meta-recsys/generative-recommenders
@@ -55,10 +55,13 @@ import torch.nn as nn
 import torch.nn.functional as F

 from nextrec.basic.model import BaseModel
+from nextrec.basic.layers import RMSNorm, EmbeddingLayer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature

+from nextrec.utils.model import select_features

-def _relative_position_bucket(
+
+def relative_position_bucket(
     relative_position: torch.Tensor,
     num_buckets: int = 32,
     max_distance: int = 128,
@@ -116,7 +119,7 @@ class RelativePositionBias(nn.Module):
         rel_pos = (
             mem - ctx
         )  # a matrix to describe all relative positions for each [i,j] pair, shape = [seq_len, seq_len]
-        buckets = _relative_position_bucket(
+        buckets = relative_position_bucket(
             rel_pos,
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
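For reference, the `mem - ctx` line builds the standard relative-position matrix: entry `[i, j]` is `j - i`, which `relative_position_bucket` then maps into a bounded number of buckets. A quick standalone illustration of that matrix (the bucketing logic itself is not reproduced here):

```python
import torch

T = 5
ctx = torch.arange(T)[:, None]  # query positions i, shape [T, 1]
mem = torch.arange(T)[None, :]  # key positions j, shape [1, T]
rel_pos = mem - ctx             # [T, T], entry [i, j] = j - i
print(rel_pos)
# tensor([[ 0,  1,  2,  3,  4],
#         [-1,  0,  1,  2,  3],
#         [-2, -1,  0,  1,  2],
#         [-3, -2, -1,  0,  1],
#         [-4, -3, -2, -1,  0]])
```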
@@ -138,39 +141,40 @@ class HSTUPointwiseAttention(nn.Module):

     def __init__(
         self,
-        d_model: int,
+        hidden_dim: int,
         num_heads: int,
         dropout: float = 0.1,
         alpha: float | None = None,
+        use_rms_norm: bool = False,
     ):
         super().__init__()
-        if d_model % num_heads != 0:
+        if hidden_dim % num_heads != 0:
             raise ValueError(
-                f"[HSTUPointwiseAttention Error] d_model({d_model}) % num_heads({num_heads}) != 0"
+                f"[HSTUPointwiseAttention Error] hidden_dim({hidden_dim}) % num_heads({num_heads}) != 0"
             )

-        self.d_model = d_model
+        self.hidden_dim = hidden_dim
         self.num_heads = num_heads
-        self.d_head = d_model // num_heads
-        self.alpha = alpha if alpha is not None else (self.d_head**-0.5)
-        # project input to 4 * d_model for U, V, Q, K
-        self.in_proj = nn.Linear(d_model, 4 * d_model, bias=True)
-        # project output back to d_model
-        self.out_proj = nn.Linear(d_model, d_model, bias=True)
+        self.head_dim = hidden_dim // num_heads
+        self.alpha = alpha if alpha is not None else (self.head_dim**-0.5)
+        # project input to 4 * hidden_dim for U, V, Q, K
+        self.in_proj = nn.Linear(hidden_dim, 4 * hidden_dim, bias=True)
+        # project output back to hidden_dim
+        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
         self.dropout = nn.Dropout(dropout)
-        self.norm = nn.LayerNorm(d_model)
+        self.norm = RMSNorm(hidden_dim) if use_rms_norm else nn.LayerNorm(hidden_dim)

-    def _reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
+    def reshape_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
-        [B, T, D] -> [B, H, T, d_head]
+        [B, T, D] -> [B, H, T, head_dim]
         """
         B, T, D = x.shape
-        return x.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
+        return x.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

     def forward(
         self,
         x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 or -inf
+        attention_mask: Optional[torch.Tensor] = None,  # [T, T] with 0 or -inf
         key_padding_mask: Optional[torch.Tensor] = None,  # [B, T], True = pad
         rab: Optional[torch.Tensor] = None,  # [1, H, T, T] or None
     ) -> torch.Tensor:
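The renamed `reshape_heads` is the usual multi-head split: view the last dimension as `num_heads x head_dim`, then move the head axis in front of the time axis. A standalone sketch of the same reshape:

```python
import torch

B, T, D, H = 2, 6, 8, 2        # batch, seq len, model dim, heads
x = torch.randn(B, T, D)
heads = x.view(B, T, H, D // H).transpose(1, 2)  # [B, T, D] -> [B, H, T, head_dim]
print(heads.shape)              # torch.Size([2, 2, 6, 4])
```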
@@ -180,10 +184,10 @@ class HSTUPointwiseAttention(nn.Module):
         h = F.silu(self.in_proj(x))  # [B, T, 4D]
         U, V, Q, K = h.chunk(4, dim=-1)  # each [B, T, D]

-        Qh = self._reshape_heads(Q)  # [B, H, T, d_head]
-        Kh = self._reshape_heads(K)  # [B, H, T, d_head]
-        Vh = self._reshape_heads(V)  # [B, H, T, d_head]
-        Uh = self._reshape_heads(U)  # [B, H, T, d_head]
+        Qh = self.reshape_heads(Q)  # [B, H, T, d_head]
+        Kh = self.reshape_heads(K)  # [B, H, T, d_head]
+        Vh = self.reshape_heads(V)  # [B, H, T, d_head]
+        Uh = self.reshape_heads(U)  # [B, H, T, d_head]

         # attention logits: QK^T (without 1/sqrt(d) and softmax)
         logits = torch.matmul(Qh, Kh.transpose(-2, -1)) * self.alpha  # [B, H, T, T]
@@ -197,10 +201,10 @@
         # 1 indicates that the (query i, key j) pair is a valid attention pair; 0 indicates it is masked out
         allowed = torch.ones_like(logits, dtype=torch.float)  # [B, H, T, T]

-        # causal mask: attn_mask is usually an upper triangular matrix of -inf with shape [T, T]
-        if attn_mask is not None:
-            allowed = allowed * (attn_mask.view(1, 1, T, T) == 0).float()
-            logits = logits + attn_mask.view(1, 1, T, T)
+        # causal mask: attention_mask is usually an upper triangular matrix of -inf with shape [T, T]
+        if attention_mask is not None:
+            allowed = allowed * (attention_mask.view(1, 1, T, T) == 0).float()
+            logits = logits + attention_mask.view(1, 1, T, T)

         # padding mask: key_padding_mask is usually [B, T], True = pad
         if key_padding_mask is not None:
@@ -211,12 +215,15 @@
             logits = logits.masked_fill(valid == 0, float("-inf"))

         # Eq.(2): A(X)V(X) = φ2(QK^T + rab) V(X) / N
-        attn = F.silu(logits)  # [B, H, T, T]
+        # Note: F.silu(-inf) = nan, so we need to handle -inf values carefully
+        # Replace -inf with a very negative value before silu to avoid nan
+        logits_safe = logits.masked_fill(torch.isinf(logits) & (logits < 0), -1e9)
+        attention = F.silu(logits_safe)  # [B, H, T, T]
         denom = allowed.sum(dim=-1, keepdim=True)  # [B, H, T, 1]
         denom = denom.clamp(min=1.0)

-        attn = attn / denom  # [B, H, T, T]
-        AV = torch.matmul(attn, Vh)  # [B, H, T, d_head]
+        attention = attention / denom  # [B, H, T, T]
+        AV = torch.matmul(attention, Vh)  # [B, H, T, head_dim]
         AV = AV.transpose(1, 2).contiguous().view(B, T, D)  # reshape back to [B, T, D]
         U_flat = Uh.transpose(1, 2).contiguous().view(B, T, D)
         y = self.out_proj(self.dropout(self.norm(AV) * U_flat))  # [B, T, D]
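The new comments describe a real numerical issue: `silu(x) = x * sigmoid(x)`, so at `x = -inf` the product is `-inf * 0`, which evaluates to NaN and then propagates through the attention output. The masked fill keeps masked logits finite so that `silu` maps them to approximately zero instead. A quick demonstration of both behaviours:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([0.5, float("-inf")])
print(F.silu(logits))  # tensor([0.3112,    nan]) -- -inf * sigmoid(-inf) = -inf * 0 = nan

# The fix from the hunk above: replace -inf with a large negative finite value first.
safe = logits.masked_fill(torch.isinf(logits) & (logits < 0), -1e9)
print(F.silu(safe))    # tensor([0.3112, -0.0000]) -- masked positions contribute ~0
```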
@@ -230,16 +237,20 @@ class HSTULayer(nn.Module):

     def __init__(
         self,
-        d_model: int,
+        hidden_dim: int,
         num_heads: int,
         dropout: float = 0.1,
         use_rab_pos: bool = True,
         rab_num_buckets: int = 32,
         rab_max_distance: int = 128,
+        use_rms_norm: bool = False,
     ):
         super().__init__()
-        self.attn = HSTUPointwiseAttention(
-            d_model=d_model, num_heads=num_heads, dropout=dropout
+        self.attention = HSTUPointwiseAttention(
+            hidden_dim=hidden_dim,
+            num_heads=num_heads,
+            dropout=dropout,
+            use_rms_norm=use_rms_norm,
         )
         self.dropout = nn.Dropout(dropout)
         self.use_rab_pos = use_rab_pos
@@ -256,7 +267,7 @@ class HSTULayer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
@@ -267,8 +278,11 @@ class HSTULayer(nn.Module):
         rab = None
         if self.use_rab_pos:
             rab = self.rel_pos_bias(seq_len=T, device=device)  # [1, H, T, T]
-        out = self.attn(
-            x=x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, rab=rab
+        out = self.attention(
+            x=x,
+            attention_mask=attention_mask,
+            key_padding_mask=key_padding_mask,
+            rab=rab,
         )
         return x + self.dropout(out)

@@ -295,7 +309,8 @@ class HSTU(BaseModel):
         sequence_features: list[SequenceFeature],
         dense_features: Optional[list[DenseFeature]] = None,
         sparse_features: Optional[list[SparseFeature]] = None,
-        d_model: Optional[int] = None,
+        item_history: str = "item_history",
+        hidden_dim: Optional[int] = None,
         num_heads: int = 8,
         num_layers: int = 4,
         max_seq_len: int = 200,
@@ -304,6 +319,8 @@
         use_rab_pos: bool = True,
         rab_num_buckets: int = 32,
         rab_max_distance: int = 128,
+        # Normalization settings
+        use_rms_norm: bool = False,
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
         task: str | list[str] | None = None,
@@ -324,22 +341,23 @@
                 "[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history)."
             )

-        # demo version: use the first SequenceFeature as the main sequence
-        self.history_feature = sequence_features[0]
+        self.item_history_feature = select_features(
+            sequence_features, [item_history], "item_history"
+        )[0]

-        hidden_dim = d_model or max(
-            int(getattr(self.history_feature, "embedding_dim", 0) or 0), 32
+        self.hidden_dim = hidden_dim or max(
+            int(getattr(self.item_history_feature, "embedding_dim", 0) or 0), 32
         )
         # Make hidden_dim divisible by num_heads
-        if hidden_dim % num_heads != 0:
-            hidden_dim = num_heads * math.ceil(hidden_dim / num_heads)
+        if self.hidden_dim % num_heads != 0:
+            self.hidden_dim = num_heads * math.ceil(self.hidden_dim / num_heads)

         self.padding_idx = (
-            self.history_feature.padding_idx
-            if self.history_feature.padding_idx is not None
+            self.item_history_feature.padding_idx
+            if self.item_history_feature.padding_idx is not None
             else 0
         )
-        self.vocab_size = self.history_feature.vocab_size
+        self.vocab_size = self.item_history_feature.vocab_size
        self.max_seq_len = max_seq_len

        super().__init__(
@@ -356,32 +374,51 @@
             **kwargs,
         )

+        # Optional contextual encoders (user/item attributes, real-time context, etc.)
+        self.context_features = [
+            feat
+            for feat in self.all_features
+            if feat.name != self.item_history_feature.name
+        ]
+        self.context_embedding = (
+            EmbeddingLayer(self.context_features) if self.context_features else None
+        )
+        self.context_proj = (
+            nn.Linear(self.context_embedding.output_dim, self.hidden_dim)
+            if self.context_embedding is not None
+            else None
+        )
+        self.context_dropout = nn.Dropout(dropout) if self.context_embedding else None
+
         # token & position embedding (paper usually includes pos embedding / RAB in encoder)
         self.token_embedding = nn.Embedding(
             num_embeddings=self.vocab_size,
-            embedding_dim=hidden_dim,
+            embedding_dim=self.hidden_dim,
             padding_idx=self.padding_idx,
         )
-        self.position_embedding = nn.Embedding(max_seq_len, hidden_dim)
+        self.position_embedding = nn.Embedding(max_seq_len, self.hidden_dim)
         self.input_dropout = nn.Dropout(dropout)

         # HSTU layers
         self.layers = nn.ModuleList(
             [
                 HSTULayer(
-                    d_model=hidden_dim,
+                    hidden_dim=self.hidden_dim,
                     num_heads=num_heads,
                     dropout=dropout,
                     use_rab_pos=use_rab_pos,
                     rab_num_buckets=rab_num_buckets,
                     rab_max_distance=rab_max_distance,
+                    use_rms_norm=use_rms_norm,
                 )
                 for _ in range(num_layers)
             ]
         )

-        self.final_norm = nn.LayerNorm(hidden_dim)
-        self.lm_head = nn.Linear(hidden_dim, self.vocab_size, bias=False)
+        self.final_norm = (
+            RMSNorm(self.hidden_dim) if use_rms_norm else nn.LayerNorm(self.hidden_dim)
+        )
+        self.lm_head = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
         if tie_embeddings:
             self.lm_head.weight = self.token_embedding.weight

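Since the context branch appears here only as module definitions, a small sketch of how such a projection feeds the additive fusion used later in `forward` may help; the shapes and names below are illustrative, not the library API.

```python
import torch
import torch.nn as nn

B, ctx_dim, hidden_dim = 2, 12, 8
context_repr = torch.randn(B, ctx_dim)         # pooled user/item attribute embeddings (illustrative)
context_proj = nn.Linear(ctx_dim, hidden_dim)  # maps context onto the sequence hidden size
last_hidden = torch.randn(B, hidden_dim)       # summary vector of the behaviour sequence

fused = last_hidden + context_proj(context_repr)  # additive fusion, mirroring the forward pass
print(fused.shape)                                # torch.Size([2, 8])
```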
@@ -403,10 +440,11 @@
             loss_params=loss_params,
         )
         self.register_regularization_weights(
-            embedding_attr="token_embedding", include_modules=["layers", "lm_head"]
+            embedding_attr="token_embedding",
+            include_modules=["layers", "lm_head", "context_proj"],
         )

-    def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
+    def build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
         """
         build causal mask of shape [T, T]: upper triangle is -inf, others are 0.
         This will be added to the logits to simulate causal structure.
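The docstring describes the standard additive causal mask. A minimal sketch of a mask with that shape and convention (the method itself also caches the full `max_seq_len` mask and slices it, which is not reproduced here):

```python
import torch

T = 4
# Upper triangle (strictly above the diagonal) is -inf, everything else is 0.
mask = torch.full((T, T), float("-inf")).triu(diagonal=1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```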
@@ -417,14 +455,14 @@
         self.causal_mask = mask
         return self.causal_mask[:seq_len, :seq_len]

-    def _trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
+    def trim_sequence(self, seq: torch.Tensor) -> torch.Tensor:
         if seq.size(1) <= self.max_seq_len:
             return seq
         return seq[:, -self.max_seq_len :]

     def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
-        seq = x[self.history_feature.name].long()  # [B, T_raw]
-        seq = self._trim_sequence(seq)  # [B, T]
+        seq = x[self.item_history_feature.name].long()  # [B, T_raw]
+        seq = self.trim_sequence(seq)  # [B, T]

         B, T = seq.shape
         device = seq.device
@@ -436,20 +474,35 @@

         # padding mask:True = pad
         padding_mask = seq.eq(self.padding_idx)  # [B, T]
-        attn_mask = self._build_causal_mask(seq_len=T, device=device)  # [T, T]
+        attention_mask = self.build_causal_mask(seq_len=T, device=device)  # [T, T]

         for layer in self.layers:
             hidden_states = layer(
-                x=hidden_states, attn_mask=attn_mask, key_padding_mask=padding_mask
+                x=hidden_states,
+                attention_mask=attention_mask,
+                key_padding_mask=padding_mask,
             )
         hidden_states = self.final_norm(hidden_states)  # [B, T, D]

         valid_lengths = (~padding_mask).sum(dim=1)  # [B]
         last_index = (valid_lengths - 1).clamp(min=0)
+
+        # For sequences with no valid tokens, we use position 0's hidden state
+        # In production, these sequences should be filtered out before inference
         last_hidden = hidden_states[
             torch.arange(B, device=device), last_index
         ]  # [B, D]

+        if self.context_embedding is not None and self.context_proj is not None:
+            context_repr = self.context_embedding(
+                x, self.context_features, squeeze_dim=True
+            )  # [B, D_ctx]
+            context_repr = self.context_proj(context_repr)  # [B, D]
+            if self.context_dropout is not None:
+                context_repr = self.context_dropout(context_repr)
+            # fuse contextual signal into the autoregressive token summary
+            last_hidden = last_hidden + context_repr
+
         logits = self.lm_head(last_hidden)  # [B, vocab_size]
         return logits
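Finally, the last-valid-token gather in `forward` can be isolated as a small example: with right-padded sequences, `valid_lengths - 1` indexes the final real token of each row, and the `clamp(min=0)` falls back to position 0 for all-padding rows, as the new comments note. A standalone sketch:

```python
import torch

B, T, D = 2, 5, 3
hidden_states = torch.randn(B, T, D)
padding_mask = torch.tensor([[False, False, False, True, True],   # 3 valid tokens
                             [True,  True,  True,  True, True]])  # all padding

valid_lengths = (~padding_mask).sum(dim=1)     # tensor([3, 0])
last_index = (valid_lengths - 1).clamp(min=0)  # tensor([2, 0]) -- row 1 falls back to position 0
last_hidden = hidden_states[torch.arange(B), last_index]  # [B, D]
print(last_hidden.shape)                       # torch.Size([2, 3])
```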