nextrec 0.4.5-py3-none-any.whl → 0.4.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/features.py +4 -0
- nextrec/basic/layers.py +103 -24
- nextrec/basic/metrics.py +71 -1
- nextrec/basic/model.py +285 -186
- nextrec/data/data_processing.py +1 -3
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/generative/__init__.py +16 -0
- nextrec/models/generative/hstu.py +110 -57
- nextrec/models/generative/rqvae.py +826 -0
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/config.py +38 -1
- nextrec/utils/embedding.py +28 -0
- nextrec/utils/initializer.py +4 -4
- nextrec/utils/synthetic_data.py +19 -0
- nextrec-0.4.7.dist-info/METADATA +376 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/RECORD +26 -25
- nextrec-0.4.5.dist-info/METADATA +0 -357
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/WHEEL +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/metrics.py
CHANGED
@@ -44,6 +44,11 @@ TASK_DEFAULT_METRICS = {
     + [f"recall@{k}" for k in (5, 10, 20)]
     + [f"ndcg@{k}" for k in (5, 10, 20)]
     + [f"mrr@{k}" for k in (5, 10, 20)],
+    # generative/multiclass next-item prediction defaults
+    "multiclass": ["accuracy"]
+    + [f"hitrate@{k}" for k in (1, 5, 10)]
+    + [f"recall@{k}" for k in (1, 5, 10)]
+    + [f"mrr@{k}" for k in (1, 5, 10)],
 }
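The `+` operators in the context lines above are plain list concatenation inside the `TASK_DEFAULT_METRICS` literal, and the new `"multiclass"` entry follows the same pattern. Evaluated by hand from the hunk (a sketch, not copied from the source), the new defaults come out to:

    TASK_DEFAULT_METRICS["multiclass"] == [
        "accuracy",
        "hitrate@1", "hitrate@5", "hitrate@10",
        "recall@1", "recall@5", "recall@10",
        "mrr@1", "mrr@5", "mrr@10",
    ]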
@@ -158,6 +163,51 @@ def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarr
     return groups


+def normalize_multiclass_inputs(
+    y_true: np.ndarray, y_pred: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Normalize multiclass inputs to consistent shapes.
+
+    y_true: [N] of class ids
+    y_pred: [N, C] of logits/probabilities
+    """
+    labels = np.asarray(y_true).reshape(-1)
+    scores = np.asarray(y_pred)
+    if scores.ndim == 1:
+        scores = scores.reshape(scores.shape[0], -1)
+    if scores.shape[0] != labels.shape[0]:
+        raise ValueError(
+            f"[Metric Warning] y_true length {labels.shape[0]} != y_pred batch {scores.shape[0]} for multiclass metrics."
+        )
+    return labels.astype(int), scores
+
+
+def multiclass_topk_hit_rate(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+    labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+    if scores.shape[1] == 0:
+        return 0.0
+    k = min(k, scores.shape[1])
+    topk_idx = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
+    hits = (topk_idx == labels[:, None]).any(axis=1)
+    return float(hits.mean()) if hits.size > 0 else 0.0
+
+
+def multiclass_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+    labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+    if scores.shape[1] == 0:
+        return 0.0
+    k = min(k, scores.shape[1])
+    # full sort for stable ranks
+    topk_idx = np.argsort(-scores, axis=1)[:, :k]
+    ranks = np.full(labels.shape, fill_value=k + 1, dtype=np.float32)
+    for idx in range(k):
+        match = topk_idx[:, idx] == labels
+        ranks[match] = idx + 1
+    reciprocals = np.where(ranks <= k, 1.0 / ranks, 0.0)
+    return float(reciprocals.mean()) if reciprocals.size > 0 else 0.0
+
+
 def compute_precision_at_k(
     y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
 ) -> float:
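For anyone exercising the new helpers directly, here is a minimal sketch that calls them as plain module-level functions from `nextrec.basic.metrics` with NumPy inputs, which is all this hunk shows about their interface (the tiny arrays below are illustrative, not from the package):

    import numpy as np
    from nextrec.basic.metrics import multiclass_mrr_at_k, multiclass_topk_hit_rate

    # 3 samples, 4 candidate classes; y_pred holds logits or probabilities
    y_true = np.array([2, 0, 3])
    y_pred = np.array([
        [0.1, 0.4, 0.3, 0.2],     # true class 2 ranked 2nd
        [0.9, 0.05, 0.03, 0.02],  # true class 0 ranked 1st
        [0.2, 0.5, 0.1, 0.05],    # true class 3 ranked 4th
    ])

    print(multiclass_topk_hit_rate(y_true, y_pred, k=2))  # 2 of 3 samples in top-2 -> ~0.667
    print(multiclass_mrr_at_k(y_true, y_pred, k=2))       # (1/2 + 1 + 0) / 3 = 0.5

Note that `multiclass_topk_hit_rate` doubles as recall@k here, since each sample has exactly one relevant class.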
@@ -463,8 +513,28 @@ def compute_single_metric(
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
+    metric_lower = metric.lower()
+    is_multiclass = task_type == "multiclass" and y_pred.ndim >= 2
+    if is_multiclass:
+        # Dedicated path for multiclass logits (e.g., next-item prediction)
+        labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+        if metric_lower in ("accuracy", "acc"):
+            preds = scores.argmax(axis=1)
+            return float((preds == labels).mean())
+        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+            k_str = metric_lower.split("@")[1]
+            k = int(k_str)
+            return multiclass_topk_hit_rate(labels, scores, k)
+        if metric_lower.startswith("recall@"):
+            k = int(metric_lower.split("@")[1])
+            return multiclass_topk_hit_rate(labels, scores, k)
+        if metric_lower.startswith("mrr@"):
+            k = int(metric_lower.split("@")[1])
+            return multiclass_mrr_at_k(labels, scores, k)
+        # fall back to accuracy if unsupported metric is requested
+        preds = scores.argmax(axis=1)
+        return float((preds == labels).mean())
     try:
-        metric_lower = metric.lower()
         if metric_lower.startswith("recall@"):
             k = int(metric_lower.split("@")[1])
             return compute_recall_at_k(y_true, y_pred, user_ids, k)  # type: ignore