nextrec-0.4.5-py3-none-any.whl → nextrec-0.4.7-py3-none-any.whl

nextrec/basic/metrics.py CHANGED
@@ -44,6 +44,11 @@ TASK_DEFAULT_METRICS = {
     + [f"recall@{k}" for k in (5, 10, 20)]
     + [f"ndcg@{k}" for k in (5, 10, 20)]
     + [f"mrr@{k}" for k in (5, 10, 20)],
+    # generative/multiclass next-item prediction defaults
+    "multiclass": ["accuracy"]
+    + [f"hitrate@{k}" for k in (1, 5, 10)]
+    + [f"recall@{k}" for k in (1, 5, 10)]
+    + [f"mrr@{k}" for k in (1, 5, 10)],
 }
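For reference, the new "multiclass" entry evaluates to a flat list of ten metric names:

# What TASK_DEFAULT_METRICS["multiclass"] expands to at import time:
["accuracy",
 "hitrate@1", "hitrate@5", "hitrate@10",
 "recall@1", "recall@5", "recall@10",
 "mrr@1", "mrr@5", "mrr@10"]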
@@ -158,6 +163,51 @@ def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
     return groups


+def normalize_multiclass_inputs(
+    y_true: np.ndarray, y_pred: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+    """
+    Normalize multiclass inputs to consistent shapes.
+
+    y_true: [N] of class ids
+    y_pred: [N, C] of logits/probabilities
+    """
+    labels = np.asarray(y_true).reshape(-1)
+    scores = np.asarray(y_pred)
+    if scores.ndim == 1:
+        scores = scores.reshape(scores.shape[0], -1)
+    if scores.shape[0] != labels.shape[0]:
+        raise ValueError(
+            f"[Metric Warning] y_true length {labels.shape[0]} != y_pred batch {scores.shape[0]} for multiclass metrics."
+        )
+    return labels.astype(int), scores
+
+
+def multiclass_topk_hit_rate(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+    labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+    if scores.shape[1] == 0:
+        return 0.0
+    k = min(k, scores.shape[1])
+    topk_idx = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
+    hits = (topk_idx == labels[:, None]).any(axis=1)
+    return float(hits.mean()) if hits.size > 0 else 0.0
+
+
+def multiclass_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
+    labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+    if scores.shape[1] == 0:
+        return 0.0
+    k = min(k, scores.shape[1])
+    # full sort for stable ranks
+    topk_idx = np.argsort(-scores, axis=1)[:, :k]
+    ranks = np.full(labels.shape, fill_value=k + 1, dtype=np.float32)
+    for idx in range(k):
+        match = topk_idx[:, idx] == labels
+        ranks[match] = idx + 1
+    reciprocals = np.where(ranks <= k, 1.0 / ranks, 0.0)
+    return float(reciprocals.mean()) if reciprocals.size > 0 else 0.0
+
+
 def compute_precision_at_k(
     y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
 ) -> float:
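A quick sanity check of the three new helpers on toy data (a sketch, not shipped with the package; the toy labels/scores and the hand-worked expected values are illustrative only):

import numpy as np

# 2 samples, 3 classes; row 0's true class is 2, row 1's is 0
labels = np.array([2, 0])
scores = np.array([[0.1, 0.3, 0.6],
                   [0.2, 0.5, 0.3]])

labels, scores = normalize_multiclass_inputs(labels, scores)  # [2] int labels, [2, 3] scores
print(multiclass_topk_hit_rate(labels, scores, k=1))  # 0.5: row 0's argmax hits, row 1's misses
print(multiclass_mrr_at_k(labels, scores, k=2))       # 0.5: mean of 1/1 (rank-1 hit) and 0.0 (not in top 2)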
@@ -463,8 +513,28 @@ def compute_single_metric(
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
+    metric_lower = metric.lower()
+    is_multiclass = task_type == "multiclass" and y_pred.ndim >= 2
+    if is_multiclass:
+        # Dedicated path for multiclass logits (e.g., next-item prediction)
+        labels, scores = normalize_multiclass_inputs(y_true, y_pred)
+        if metric_lower in ("accuracy", "acc"):
+            preds = scores.argmax(axis=1)
+            return float((preds == labels).mean())
+        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+            k_str = metric_lower.split("@")[1]
+            k = int(k_str)
+            return multiclass_topk_hit_rate(labels, scores, k)
+        if metric_lower.startswith("recall@"):
+            k = int(metric_lower.split("@")[1])
+            return multiclass_topk_hit_rate(labels, scores, k)
+        if metric_lower.startswith("mrr@"):
+            k = int(metric_lower.split("@")[1])
+            return multiclass_mrr_at_k(labels, scores, k)
+        # fall back to accuracy if unsupported metric is requested
+        preds = scores.argmax(axis=1)
+        return float((preds == labels).mean())
     try:
-        metric_lower = metric.lower()
         if metric_lower.startswith("recall@"):
             k = int(metric_lower.split("@")[1])
             return compute_recall_at_k(y_true, y_pred, user_ids, k)  # type: ignore
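Two behaviors in this branch are worth noting: recall@k routes to the same helper as hitrate@k (with exactly one true class per row the two metrics coincide), and an unrecognized metric name silently falls back to top-1 accuracy rather than raising. A condensed, behavior-equivalent paraphrase of the dispatch (hypothetical standalone function, not the package's API; reuses the helpers added above):

import numpy as np

def multiclass_dispatch(metric: str, labels: np.ndarray, scores: np.ndarray) -> float:
    # Condensed paraphrase of the is_multiclass branch in compute_single_metric
    m = metric.lower()
    if m.startswith(("hitrate@", "hr@", "recall@")):
        # single relevant item per row, so recall@k == hitrate@k
        return multiclass_topk_hit_rate(labels, scores, int(m.split("@")[1]))
    if m.startswith("mrr@"):
        return multiclass_mrr_at_k(labels, scores, int(m.split("@")[1]))
    # "accuracy", "acc", and any unsupported name all reduce to top-1 accuracy
    return float((scores.argmax(axis=1) == labels).mean())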