nextrec 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.5"
+ __version__ = "0.4.6"
nextrec/basic/features.py CHANGED
@@ -33,6 +33,8 @@ class SequenceFeature(BaseFeature):
  l1_reg: float = 0.0,
  l2_reg: float = 1e-5,
  trainable: bool = True,
+ pretrained_weight: torch.Tensor | None = None,
+ freeze_pretrained: bool = False,
  ):
  self.name = name
  self.vocab_size = vocab_size
@@ -47,6 +49,8 @@ class SequenceFeature(BaseFeature):
  self.l1_reg = l1_reg
  self.l2_reg = l2_reg
  self.trainable = trainable
+ self.pretrained_weight = pretrained_weight
+ self.freeze_pretrained = freeze_pretrained


  class SparseFeature(BaseFeature):
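
Note: the two new SequenceFeature arguments above let callers inject a pretrained embedding table. A minimal usage sketch — only name, vocab_size, pretrained_weight and freeze_pretrained are visible in this hunk; embedding_dim is an assumed parameter name:

    import torch
    from nextrec.basic.features import SequenceFeature

    pretrained = torch.randn(10000, 64)          # e.g. item vectors trained elsewhere
    hist_items = SequenceFeature(
        name="hist_item_ids",
        vocab_size=10000,
        embedding_dim=64,                        # assumed parameter name, not shown in this hunk
        pretrained_weight=pretrained,            # new in 0.4.6
        freeze_pretrained=True,                  # new in 0.4.6: keep the pretrained table fixed
    )
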
nextrec/basic/layers.py CHANGED
@@ -496,12 +496,18 @@ class HadamardInteractionLayer(nn.Module):


  class MultiHeadSelfAttention(nn.Module):
+ """
+ Multi-Head Self-Attention layer with Flash Attention support.
+ Uses PyTorch 2.0+ scaled_dot_product_attention when available for better performance.
+ """
+
  def __init__(
  self,
  embedding_dim: int,
  num_heads: int = 2,
  dropout: float = 0.0,
  use_residual: bool = True,
+ use_layer_norm: bool = False,
  ):
  super().__init__()
  if embedding_dim % num_heads != 0:
@@ -512,45 +518,100 @@ class MultiHeadSelfAttention(nn.Module):
  self.num_heads = num_heads
  self.head_dim = embedding_dim // num_heads
  self.use_residual = use_residual
+ self.dropout_rate = dropout
+
  self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
  self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
  self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
+ self.W_O = nn.Linear(embedding_dim, embedding_dim, bias=False)
+
  if self.use_residual:
  self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
+ if use_layer_norm:
+ self.layer_norm = nn.LayerNorm(embedding_dim)
+ else:
+ self.layer_norm = None
+
  self.dropout = nn.Dropout(dropout)
+ # Check if Flash Attention is available
+ self.use_flash_attention = hasattr(F, "scaled_dot_product_attention")

- def forward(self, x: torch.Tensor) -> torch.Tensor:
- batch_size, num_fields, _ = x.shape
- Q = self.W_Q(x) # [batch_size, num_fields, embedding_dim]
+ def forward(
+ self, x: torch.Tensor, attention_mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: [batch_size, seq_len, embedding_dim]
+ attention_mask: [batch_size, seq_len] or [batch_size, seq_len, seq_len], boolean mask where True indicates valid positions
+ Returns:
+ output: [batch_size, seq_len, embedding_dim]
+ """
+ batch_size, seq_len, _ = x.shape
+ Q = self.W_Q(x) # [batch_size, seq_len, embedding_dim]
  K = self.W_K(x)
  V = self.W_V(x)
- # Split into multiple heads: [batch_size, num_heads, num_fields, head_dim]
- Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
- 1, 2
- )
- K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
- 1, 2
- )
- V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(
- 1, 2
- )
- # Attention scores
- scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
- attention_weights = F.softmax(scores, dim=-1)
- attention_weights = self.dropout(attention_weights)
- attention_output = torch.matmul(
- attention_weights, V
- ) # [batch_size, num_heads, num_fields, head_dim]
+
+ # Split into multiple heads: [batch_size, num_heads, seq_len, head_dim]
+ Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ if self.use_flash_attention:
+ # Use PyTorch 2.0+ Flash Attention
+ if attention_mask is not None:
+ # Convert mask to [batch_size, 1, seq_len, seq_len] format
+ if attention_mask.dim() == 2:
+ # [B, L] -> [B, 1, 1, L]
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ elif attention_mask.dim() == 3:
+ # [B, L, L] -> [B, 1, L, L]
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_output = F.scaled_dot_product_attention(
+ Q,
+ K,
+ V,
+ attn_mask=attention_mask,
+ dropout_p=self.dropout_rate if self.training else 0.0,
+ )
+ # Handle potential NaN values
+ attention_output = torch.nan_to_num(attention_output, nan=0.0)
+ else:
+ # Fallback to standard attention
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+
+ if attention_mask is not None:
+ # Process mask for standard attention
+ if attention_mask.dim() == 2:
+ # [B, L] -> [B, 1, 1, L]
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ elif attention_mask.dim() == 3:
+ # [B, L, L] -> [B, 1, L, L]
+ attention_mask = attention_mask.unsqueeze(1)
+ scores = scores.masked_fill(~attention_mask, float("-1e9"))
+
+ attention_weights = F.softmax(scores, dim=-1)
+ attention_weights = self.dropout(attention_weights)
+ attention_output = torch.matmul(
+ attention_weights, V
+ ) # [batch_size, num_heads, seq_len, head_dim]
+
  # Concatenate heads
  attention_output = attention_output.transpose(1, 2).contiguous()
  attention_output = attention_output.view(
- batch_size, num_fields, self.embedding_dim
+ batch_size, seq_len, self.embedding_dim
  )
+
+ # Output projection
+ output = self.W_O(attention_output)
+
  # Residual connection
  if self.use_residual:
- output = attention_output + self.W_Res(x)
- else:
- output = attention_output
+ output = output + self.W_Res(x)
+
+ # Layer normalization
+ if self.layer_norm is not None:
+ output = self.layer_norm(output)
+
  output = F.relu(output)
  return output

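Usage sketch for the reworked layer: forward now accepts an optional boolean attention_mask (True marks valid positions, per the new docstring) and routes through F.scaled_dot_product_attention when PyTorch provides it. The import path follows this file (nextrec/basic/layers.py); shapes are illustrative.

    import torch
    from nextrec.basic.layers import MultiHeadSelfAttention

    attn = MultiHeadSelfAttention(embedding_dim=64, num_heads=2, use_layer_norm=True)
    x = torch.randn(8, 20, 64)                   # [batch, seq_len, embedding_dim]
    mask = torch.ones(8, 20, dtype=torch.bool)   # [batch, seq_len], True = keep
    mask[:, 15:] = False                         # last 5 positions are padding
    out = attn(x, attention_mask=mask)           # [8, 20, 64]
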
@@ -653,3 +714,21 @@ class AttentionPoolingLayer(nn.Module):
  # Weighted sum over keys: (B, L, 1) * (B, L, D) -> (B, D)
  output = torch.sum(attention_weights * keys, dim=1)
  return output
+
+
+ class RMSNorm(torch.nn.Module):
+ """
+ Root Mean Square Layer Normalization.
+ Reference: https://arxiv.org/abs/1910.07467
+ """
+
+ def __init__(self, hidden_size: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # RMS(x) = sqrt(mean(x^2) + eps)
+ variance = torch.mean(x**2, dim=-1, keepdim=True)
+ x_normalized = x * torch.rsqrt(variance + self.eps)
+ return self.weight * x_normalized
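
A short sketch of the added RMSNorm (import path assumed to be nextrec.basic.layers, since the class is added to that file): it rescales each vector by its root mean square and applies a learnable per-dimension gain, with no mean-centering and no bias.

    import torch
    from nextrec.basic.layers import RMSNorm

    norm = RMSNorm(hidden_size=64)
    x = torch.randn(4, 64)
    y = norm(x)
    # With the default weight of ones, every row of y has approximately unit RMS:
    print(y.pow(2).mean(dim=-1).sqrt())          # ~1.0 per row
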
nextrec/basic/metrics.py CHANGED
@@ -44,6 +44,11 @@ TASK_DEFAULT_METRICS = {
  + [f"recall@{k}" for k in (5, 10, 20)]
  + [f"ndcg@{k}" for k in (5, 10, 20)]
  + [f"mrr@{k}" for k in (5, 10, 20)],
+ # generative/multiclass next-item prediction defaults
+ "multiclass": ["accuracy"]
+ + [f"hitrate@{k}" for k in (1, 5, 10)]
+ + [f"recall@{k}" for k in (1, 5, 10)]
+ + [f"mrr@{k}" for k in (1, 5, 10)],
  }


@@ -158,6 +163,51 @@ def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarr
  return groups


+ def normalize_multiclass_inputs(
+ y_true: np.ndarray, y_pred: np.ndarray
+ ) -> tuple[np.ndarray, np.ndarray]:
+ """
+ Normalize multiclass inputs to consistent shapes.
+
+ y_true: [N] of class ids
+ y_pred: [N, C] of logits/probabilities
+ """
+ labels = np.asarray(y_true).reshape(-1)
+ scores = np.asarray(y_pred)
+ if scores.ndim == 1:
+ scores = scores.reshape(scores.shape[0], -1)
+ if scores.shape[0] != labels.shape[0]:
+ raise ValueError(
+ f"[Metric Warning] y_true length {labels.shape[0]} != y_pred batch {scores.shape[0]} for multiclass metrics."
+ )
+ return labels.astype(int), scores
+
+
+ def multiclass_topk_hit_rate(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+ labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+ if scores.shape[1] == 0:
+ return 0.0
+ k = min(k, scores.shape[1])
+ topk_idx = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
+ hits = (topk_idx == labels[:, None]).any(axis=1)
+ return float(hits.mean()) if hits.size > 0 else 0.0
+
+
+ def multiclass_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+ labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+ if scores.shape[1] == 0:
+ return 0.0
+ k = min(k, scores.shape[1])
+ # full sort for stable ranks
+ topk_idx = np.argsort(-scores, axis=1)[:, :k]
+ ranks = np.full(labels.shape, fill_value=k + 1, dtype=np.float32)
+ for idx in range(k):
+ match = topk_idx[:, idx] == labels
+ ranks[match] = idx + 1
+ reciprocals = np.where(ranks <= k, 1.0 / ranks, 0.0)
+ return float(reciprocals.mean()) if reciprocals.size > 0 else 0.0
+
+
  def compute_precision_at_k(
  y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
  ) -> float:
@@ -463,8 +513,28 @@ def compute_single_metric(
  ) -> float:
  """Compute a single metric given true and predicted values."""
  y_p_binary = (y_pred > 0.5).astype(int)
+ metric_lower = metric.lower()
+ is_multiclass = task_type == "multiclass" and y_pred.ndim >= 2
+ if is_multiclass:
+ # Dedicated path for multiclass logits (e.g., next-item prediction)
+ labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+ if metric_lower in ("accuracy", "acc"):
+ preds = scores.argmax(axis=1)
+ return float((preds == labels).mean())
+ if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+ k_str = metric_lower.split("@")[1]
+ k = int(k_str)
+ return multiclass_topk_hit_rate(labels, scores, k)
+ if metric_lower.startswith("recall@"):
+ k = int(metric_lower.split("@")[1])
+ return multiclass_topk_hit_rate(labels, scores, k)
+ if metric_lower.startswith("mrr@"):
+ k = int(metric_lower.split("@")[1])
+ return multiclass_mrr_at_k(labels, scores, k)
+ # fall back to accuracy if unsupported metric is requested
+ preds = scores.argmax(axis=1)
+ return float((preds == labels).mean())
  try:
- metric_lower = metric.lower()
  if metric_lower.startswith("recall@"):
  k = int(metric_lower.split("@")[1])
  return compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
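
Worked example for the new multiclass helpers (function names come from this diff; the numbers are illustrative only):

    import numpy as np
    from nextrec.basic.metrics import multiclass_topk_hit_rate, multiclass_mrr_at_k

    y_true = np.array([2, 0, 1])                 # target class ids
    y_pred = np.array([
        [0.1, 0.2, 0.6, 0.1],                    # class 2 ranked 1st -> hit, rr = 1
        [0.3, 0.5, 0.1, 0.1],                    # class 0 ranked 2nd -> hit, rr = 1/2
        [0.2, 0.1, 0.4, 0.3],                    # class 1 ranked 4th -> miss, rr = 0
    ])
    multiclass_topk_hit_rate(y_true, y_pred, k=2)   # 2/3 ≈ 0.667
    multiclass_mrr_at_k(y_true, y_pred, k=2)        # (1 + 0.5 + 0) / 3 = 0.5
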
nextrec/basic/model.py CHANGED
@@ -126,11 +126,9 @@ class BaseModel(FeatureSet, nn.Module):
  self.session = create_session(session_id)
  self.session_path = self.session.root # pwd/session_id, path for this session
  self.checkpoint_path = os.path.join(
- self.session_path, self.model_name + "_checkpoint.model"
- ) # example: pwd/session_id/DeepFM_checkpoint.model
- self.best_path = os.path.join(
- self.session_path, self.model_name + "_best.model"
- )
+ self.session_path, self.model_name + "_checkpoint.pt"
+ ) # example: pwd/session_id/DeepFM_checkpoint.pt
+ self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
  self.features_config_path = os.path.join(
  self.session_path, "features_config.pkl"
  )
@@ -1563,7 +1561,7 @@ class BaseModel(FeatureSet, nn.Module):
  path=save_path,
  default_dir=self.session_path,
  default_name=self.model_name,
- suffix=".model",
+ suffix=".pt",
  add_timestamp=add_timestamp,
  )
  model_path = Path(target_path)
@@ -1603,16 +1601,16 @@ class BaseModel(FeatureSet, nn.Module):
  self.to(self.device)
  base_path = Path(save_path)
  if base_path.is_dir():
- model_files = sorted(base_path.glob("*.model"))
+ model_files = sorted(base_path.glob("*.pt"))
  if not model_files:
  raise FileNotFoundError(
- f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}"
+ f"[BaseModel-load-model Error] No *.pt file found in directory: {base_path}"
  )
  model_path = model_files[-1]
  config_dir = base_path
  else:
  model_path = (
- base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+ base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
  )
  config_dir = model_path.parent
  if not model_path.exists():
@@ -1665,21 +1663,21 @@ class BaseModel(FeatureSet, nn.Module):
  ) -> "BaseModel":
  """
  Load a model from a checkpoint path. The checkpoint path should contain:
- a .model file and a features_config.pkl file.
+ a .pt file and a features_config.pkl file.
  """
  base_path = Path(checkpoint_path)
  verbose = kwargs.pop("verbose", True)
  if base_path.is_dir():
- model_candidates = sorted(base_path.glob("*.model"))
+ model_candidates = sorted(base_path.glob("*.pt"))
  if not model_candidates:
  raise FileNotFoundError(
- f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}"
+ f"[BaseModel-from-checkpoint Error] No *.pt file found under: {base_path}"
  )
  model_file = model_candidates[-1]
  config_dir = base_path
  else:
  model_file = (
- base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+ base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
  )
  config_dir = model_file.parent
  features_config_path = config_dir / "features_config.pkl"
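
With the suffix change, 0.4.6 sessions contain <ModelName>_checkpoint.pt and <ModelName>_best.pt next to features_config.pkl, and directories saved by 0.4.5 (*.model) will no longer be matched by the new *.pt glob. A hedged loading sketch — DeepFM is taken from the comment above, but its exact import location is an assumption:

    from nextrec.models import DeepFM            # assumed import path

    # Directory holding a *.pt checkpoint plus features_config.pkl
    model = DeepFM.from_checkpoint("session_id/")
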
@@ -25,9 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
  raise KeyError(f"Unsupported data type for extracting column {name}")


- def split_dict_random(
- data_dict: dict, test_size: float = 0.2, random_state: int | None = None
- ):
+ def split_dict_random(data_dict, test_size=0.2, random_state=None):

  lengths = [len(v) for v in data_dict.values()]
  if len(set(lengths)) != 1:
@@ -0,0 +1,16 @@
+ """
+ Generative Recommendation Models
+
+ This module contains generative models for recommendation tasks.
+ """
+
+ from nextrec.models.generative.hstu import HSTU
+ from nextrec.models.generative.rqvae import (
+ RQVAE,
+ RQ,
+ VQEmbedding,
+ BalancedKmeans,
+ kmeans,
+ )
+
+ __all__ = ["HSTU", "RQVAE", "RQ", "VQEmbedding", "BalancedKmeans", "kmeans"]
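
Assuming this new file is the sub-package's __init__.py (its imports suggest nextrec/models/generative/__init__.py), the re-exports mean the public names can be imported directly; only the names in __all__ above are assumed to exist:

    from nextrec.models.generative import HSTU, RQVAE, BalancedKmeans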