nextrec 0.4.5-py3-none-any.whl → 0.4.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/features.py +4 -0
- nextrec/basic/layers.py +103 -24
- nextrec/basic/metrics.py +71 -1
- nextrec/basic/model.py +11 -13
- nextrec/data/data_processing.py +1 -3
- nextrec/models/generative/__init__.py +16 -0
- nextrec/models/generative/hstu.py +110 -57
- nextrec/models/generative/rqvae.py +826 -0
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/config.py +38 -1
- nextrec/utils/embedding.py +28 -0
- nextrec/utils/initializer.py +4 -4
- nextrec/utils/synthetic_data.py +19 -0
- nextrec-0.4.6.dist-info/METADATA +371 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.6.dist-info}/RECORD +19 -18
- nextrec-0.4.5.dist-info/METADATA +0 -357
- {nextrec-0.4.5.dist-info → nextrec-0.4.6.dist-info}/WHEEL +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.6.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.6.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.5"
+__version__ = "0.4.6"
nextrec/basic/features.py
CHANGED
@@ -33,6 +33,8 @@ class SequenceFeature(BaseFeature):
         l1_reg: float = 0.0,
         l2_reg: float = 1e-5,
         trainable: bool = True,
+        pretrained_weight: torch.Tensor | None = None,
+        freeze_pretrained: bool = False,
     ):
         self.name = name
         self.vocab_size = vocab_size
@@ -47,6 +49,8 @@ class SequenceFeature(BaseFeature):
         self.l1_reg = l1_reg
         self.l2_reg = l2_reg
         self.trainable = trainable
+        self.pretrained_weight = pretrained_weight
+        self.freeze_pretrained = freeze_pretrained


 class SparseFeature(BaseFeature):
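For context, a minimal usage sketch of the new pretrained-embedding options. Only name, vocab_size, pretrained_weight, and freeze_pretrained are confirmed by this diff; the remaining argument names are assumptions.

import torch
from nextrec.basic.features import SequenceFeature

# Hypothetical example: reuse an existing item-embedding table and keep it frozen.
pretrained = torch.randn(10000, 64)  # vocab_size x embedding_dim, e.g. from a prior run

hist_items = SequenceFeature(
    name="hist_item_ids",
    vocab_size=10000,
    embedding_dim=64,              # assumed argument name, not shown in this diff
    pretrained_weight=pretrained,  # new in 0.4.6
    freeze_pretrained=True,        # new in 0.4.6: keep the pretrained table fixed
)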
nextrec/basic/layers.py
CHANGED
@@ -496,12 +496,18 @@ class HadamardInteractionLayer(nn.Module):


 class MultiHeadSelfAttention(nn.Module):
+    """
+    Multi-Head Self-Attention layer with Flash Attention support.
+    Uses PyTorch 2.0+ scaled_dot_product_attention when available for better performance.
+    """
+
     def __init__(
         self,
         embedding_dim: int,
         num_heads: int = 2,
         dropout: float = 0.0,
         use_residual: bool = True,
+        use_layer_norm: bool = False,
     ):
         super().__init__()
         if embedding_dim % num_heads != 0:
@@ -512,45 +518,100 @@ class MultiHeadSelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = embedding_dim // num_heads
         self.use_residual = use_residual
+        self.dropout_rate = dropout
+
         self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
         self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        self.W_O = nn.Linear(embedding_dim, embedding_dim, bias=False)
+
         if self.use_residual:
             self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        if use_layer_norm:
+            self.layer_norm = nn.LayerNorm(embedding_dim)
+        else:
+            self.layer_norm = None
+
         self.dropout = nn.Dropout(dropout)
+        # Check if Flash Attention is available
+        self.use_flash_attention = hasattr(F, "scaled_dot_product_attention")

-    def forward(
-
-
+    def forward(
+        self, x: torch.Tensor, attention_mask: torch.Tensor | None = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: [batch_size, seq_len, embedding_dim]
+            attention_mask: [batch_size, seq_len] or [batch_size, seq_len, seq_len], boolean mask where True indicates valid positions
+        Returns:
+            output: [batch_size, seq_len, embedding_dim]
+        """
+        batch_size, seq_len, _ = x.shape
+        Q = self.W_Q(x)  # [batch_size, seq_len, embedding_dim]
         K = self.W_K(x)
         V = self.W_V(x)
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Split into multiple heads: [batch_size, num_heads, seq_len, head_dim]
+        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        if self.use_flash_attention:
+            # Use PyTorch 2.0+ Flash Attention
+            if attention_mask is not None:
+                # Convert mask to [batch_size, 1, seq_len, seq_len] format
+                if attention_mask.dim() == 2:
+                    # [B, L] -> [B, 1, 1, L]
+                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                elif attention_mask.dim() == 3:
+                    # [B, L, L] -> [B, 1, L, L]
+                    attention_mask = attention_mask.unsqueeze(1)
+            attention_output = F.scaled_dot_product_attention(
+                Q,
+                K,
+                V,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout_rate if self.training else 0.0,
+            )
+            # Handle potential NaN values
+            attention_output = torch.nan_to_num(attention_output, nan=0.0)
+        else:
+            # Fallback to standard attention
+            scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+
+            if attention_mask is not None:
+                # Process mask for standard attention
+                if attention_mask.dim() == 2:
+                    # [B, L] -> [B, 1, 1, L]
+                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+                elif attention_mask.dim() == 3:
+                    # [B, L, L] -> [B, 1, L, L]
+                    attention_mask = attention_mask.unsqueeze(1)
+                scores = scores.masked_fill(~attention_mask, float("-1e9"))
+
+            attention_weights = F.softmax(scores, dim=-1)
+            attention_weights = self.dropout(attention_weights)
+            attention_output = torch.matmul(
+                attention_weights, V
+            )  # [batch_size, num_heads, seq_len, head_dim]
+
         # Concatenate heads
         attention_output = attention_output.transpose(1, 2).contiguous()
         attention_output = attention_output.view(
-            batch_size,
+            batch_size, seq_len, self.embedding_dim
         )
+
+        # Output projection
+        output = self.W_O(attention_output)
+
         # Residual connection
         if self.use_residual:
-            output =
-
-
+            output = output + self.W_Res(x)
+
+        # Layer normalization
+        if self.layer_norm is not None:
+            output = self.layer_norm(output)
+
         output = F.relu(output)
         return output

@@ -653,3 +714,21 @@ class AttentionPoolingLayer(nn.Module):
         # Weighted sum over keys: (B, L, 1) * (B, L, D) -> (B, D)
         output = torch.sum(attention_weights * keys, dim=1)
         return output
+
+
+class RMSNorm(torch.nn.Module):
+    """
+    Root Mean Square Layer Normalization.
+    Reference: https://arxiv.org/abs/1910.07467
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # RMS(x) = sqrt(mean(x^2) + eps)
+        variance = torch.mean(x**2, dim=-1, keepdim=True)
+        x_normalized = x * torch.rsqrt(variance + self.eps)
+        return self.weight * x_normalized
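A short sketch of how the updated attention layer and the new RMSNorm can be called, using only the constructor and forward argument names visible in the hunks above; shapes and values are illustrative, not from the package docs.

import torch
from nextrec.basic.layers import MultiHeadSelfAttention, RMSNorm

x = torch.randn(4, 20, 64)                  # [batch, seq_len, embedding_dim]
mask = torch.ones(4, 20, dtype=torch.bool)  # True = valid position
mask[:, 15:] = False                        # last 5 positions are padding

attn = MultiHeadSelfAttention(
    embedding_dim=64, num_heads=4, dropout=0.1, use_layer_norm=True
)
out = attn(x, attention_mask=mask)          # [4, 20, 64]; uses flash attention on PyTorch 2.0+

norm = RMSNorm(hidden_size=64)
y = norm(out)                               # rescales so the RMS over the last dim is ~1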
nextrec/basic/metrics.py
CHANGED
@@ -44,6 +44,11 @@ TASK_DEFAULT_METRICS = {
     + [f"recall@{k}" for k in (5, 10, 20)]
     + [f"ndcg@{k}" for k in (5, 10, 20)]
     + [f"mrr@{k}" for k in (5, 10, 20)],
+    # generative/multiclass next-item prediction defaults
+    "multiclass": ["accuracy"]
+    + [f"hitrate@{k}" for k in (1, 5, 10)]
+    + [f"recall@{k}" for k in (1, 5, 10)]
+    + [f"mrr@{k}" for k in (1, 5, 10)],
 }


@@ -158,6 +163,51 @@ def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
     return groups


+def normalize_multiclass_inputs(
+    y_true: np.ndarray, y_pred: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Normalize multiclass inputs to consistent shapes.
+
+    y_true: [N] of class ids
+    y_pred: [N, C] of logits/probabilities
+    """
+    labels = np.asarray(y_true).reshape(-1)
+    scores = np.asarray(y_pred)
+    if scores.ndim == 1:
+        scores = scores.reshape(scores.shape[0], -1)
+    if scores.shape[0] != labels.shape[0]:
+        raise ValueError(
+            f"[Metric Warning] y_true length {labels.shape[0]} != y_pred batch {scores.shape[0]} for multiclass metrics."
+        )
+    return labels.astype(int), scores
+
+
+def multiclass_topk_hit_rate(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+    labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+    if scores.shape[1] == 0:
+        return 0.0
+    k = min(k, scores.shape[1])
+    topk_idx = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
+    hits = (topk_idx == labels[:, None]).any(axis=1)
+    return float(hits.mean()) if hits.size > 0 else 0.0
+
+
+def multiclass_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+    labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+    if scores.shape[1] == 0:
+        return 0.0
+    k = min(k, scores.shape[1])
+    # full sort for stable ranks
+    topk_idx = np.argsort(-scores, axis=1)[:, :k]
+    ranks = np.full(labels.shape, fill_value=k + 1, dtype=np.float32)
+    for idx in range(k):
+        match = topk_idx[:, idx] == labels
+        ranks[match] = idx + 1
+    reciprocals = np.where(ranks <= k, 1.0 / ranks, 0.0)
+    return float(reciprocals.mean()) if reciprocals.size > 0 else 0.0
+
+
 def compute_precision_at_k(
     y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
 ) -> float:
@@ -463,8 +513,28 @@ def compute_single_metric(
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
+    metric_lower = metric.lower()
+    is_multiclass = task_type == "multiclass" and y_pred.ndim >= 2
+    if is_multiclass:
+        # Dedicated path for multiclass logits (e.g., next-item prediction)
+        labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+        if metric_lower in ("accuracy", "acc"):
+            preds = scores.argmax(axis=1)
+            return float((preds == labels).mean())
+        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+            k_str = metric_lower.split("@")[1]
+            k = int(k_str)
+            return multiclass_topk_hit_rate(labels, scores, k)
+        if metric_lower.startswith("recall@"):
+            k = int(metric_lower.split("@")[1])
+            return multiclass_topk_hit_rate(labels, scores, k)
+        if metric_lower.startswith("mrr@"):
+            k = int(metric_lower.split("@")[1])
+            return multiclass_mrr_at_k(labels, scores, k)
+        # fall back to accuracy if unsupported metric is requested
+        preds = scores.argmax(axis=1)
+        return float((preds == labels).mean())
     try:
-        metric_lower = metric.lower()
         if metric_lower.startswith("recall@"):
             k = int(metric_lower.split("@")[1])
             return compute_recall_at_k(y_true, y_pred, user_ids, k)  # type: ignore
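A small hand-checkable example of the new multiclass top-k metrics; the numbers below are toy values, not taken from the package's tests.

import numpy as np
from nextrec.basic.metrics import multiclass_topk_hit_rate, multiclass_mrr_at_k

# 3 samples, 4 classes of scores; the true classes are 2, 0, 3.
y_true = np.array([2, 0, 3])
y_pred = np.array([
    [0.1, 0.2, 0.9, 0.3],   # true class 2 is ranked 1st
    [0.2, 0.8, 0.1, 0.4],   # true class 0 is ranked 3rd (outside top-2)
    [0.5, 0.3, 0.2, 0.1],   # true class 3 is ranked 4th (outside top-2)
])

print(multiclass_topk_hit_rate(y_true, y_pred, k=2))  # 1/3: one sample hits in the top-2
print(multiclass_mrr_at_k(y_true, y_pred, k=2))       # (1/1 + 0 + 0) / 3 ≈ 0.333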
nextrec/basic/model.py
CHANGED
@@ -126,11 +126,9 @@ class BaseModel(FeatureSet, nn.Module):
         self.session = create_session(session_id)
         self.session_path = self.session.root  # pwd/session_id, path for this session
         self.checkpoint_path = os.path.join(
-            self.session_path, self.model_name + "_checkpoint.
-        )  # example: pwd/session_id/DeepFM_checkpoint.
-        self.best_path = os.path.join(
-            self.session_path, self.model_name + "_best.model"
-        )
+            self.session_path, self.model_name + "_checkpoint.pt"
+        )  # example: pwd/session_id/DeepFM_checkpoint.pt
+        self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
         self.features_config_path = os.path.join(
             self.session_path, "features_config.pkl"
         )
@@ -1563,7 +1561,7 @@ class BaseModel(FeatureSet, nn.Module):
             path=save_path,
             default_dir=self.session_path,
             default_name=self.model_name,
-            suffix=".
+            suffix=".pt",
             add_timestamp=add_timestamp,
         )
         model_path = Path(target_path)
@@ -1603,16 +1601,16 @@ class BaseModel(FeatureSet, nn.Module):
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
-            model_files = sorted(base_path.glob("*.
+            model_files = sorted(base_path.glob("*.pt"))
             if not model_files:
                 raise FileNotFoundError(
-                    f"[BaseModel-load-model Error] No *.
+                    f"[BaseModel-load-model Error] No *.pt file found in directory: {base_path}"
                 )
             model_path = model_files[-1]
             config_dir = base_path
         else:
             model_path = (
-                base_path.with_suffix(".
+                base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
             )
             config_dir = model_path.parent
         if not model_path.exists():
@@ -1665,21 +1663,21 @@ class BaseModel(FeatureSet, nn.Module):
     ) -> "BaseModel":
         """
         Load a model from a checkpoint path. The checkpoint path should contain:
-        a .
+        a .pt file and a features_config.pkl file.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
         if base_path.is_dir():
-            model_candidates = sorted(base_path.glob("*.
+            model_candidates = sorted(base_path.glob("*.pt"))
             if not model_candidates:
                 raise FileNotFoundError(
-                    f"[BaseModel-from-checkpoint Error] No *.
+                    f"[BaseModel-from-checkpoint Error] No *.pt file found under: {base_path}"
                 )
             model_file = model_candidates[-1]
             config_dir = base_path
         else:
             model_file = (
-                base_path.with_suffix(".
+                base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
             )
             config_dir = model_file.parent
         features_config_path = config_dir / "features_config.pkl"
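Sketch of the ".pt" checkpoint resolution implied by the hunks above. The method name from_checkpoint is inferred from the error tag and docstring; the exact call signature is an assumption.

from nextrec.basic.model import BaseModel

# Directory: the last *.pt file (sorted by name) in the directory is loaded,
# together with features_config.pkl from the same directory.
model = BaseModel.from_checkpoint("session_01")

# Path without a suffix: ".pt" is appended before loading; a path that already
# carries a suffix is used as-is.
model = BaseModel.from_checkpoint("session_01/DeepFM_checkpoint")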
nextrec/data/data_processing.py
CHANGED
@@ -25,9 +25,7 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
     raise KeyError(f"Unsupported data type for extracting column {name}")


-def split_dict_random(
-    data_dict: dict, test_size: float = 0.2, random_state: int | None = None
-):
+def split_dict_random(data_dict, test_size=0.2, random_state=None):

     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
nextrec/models/generative/__init__.py
ADDED
@@ -0,0 +1,16 @@
+"""
+Generative Recommendation Models
+
+This module contains generative models for recommendation tasks.
+"""
+
+from nextrec.models.generative.hstu import HSTU
+from nextrec.models.generative.rqvae import (
+    RQVAE,
+    RQ,
+    VQEmbedding,
+    BalancedKmeans,
+    kmeans,
+)
+
+__all__ = ["HSTU", "RQVAE", "RQ", "VQEmbedding", "BalancedKmeans", "kmeans"]
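The new generative subpackage exports an RQ-VAE stack (rqvae.py itself is not shown in this diff). As background, here is a self-contained NumPy sketch of residual quantization, the idea behind RQ-VAE-style semantic IDs; it is illustrative only and does not use the nextrec API.

import numpy as np

def residual_quantize(x, codebooks):
    """Return one code index per level; each level quantizes the residual left by the previous one."""
    codes, residual = [], x.copy()
    for codebook in codebooks:                        # codebook: [num_codes, dim]
        dists = np.linalg.norm(codebook - residual, axis=1)
        idx = int(np.argmin(dists))
        codes.append(idx)
        residual = residual - codebook[idx]           # pass the remainder to the next level
    return codes, residual

rng = np.random.default_rng(0)
codebooks = [rng.normal(size=(256, 16)) for _ in range(3)]  # 3 levels, 256 codes each
item_embedding = rng.normal(size=16)
codes, err = residual_quantize(item_embedding, codebooks)
print(codes)  # three integers in [0, 256): a multi-level "semantic ID" for the item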