nextrec 0.3.6-py3-none-any.whl → 0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/metrics.py CHANGED
@@ -5,22 +5,45 @@ Date: create on 27/10/2025
 Checkpoint: edit on 02/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
+
 import logging
 from typing import Any
 
 import numpy as np
 from sklearn.metrics import (
-    roc_auc_score, log_loss, mean_squared_error, mean_absolute_error,
-    accuracy_score, precision_score, recall_score, f1_score, r2_score,
+    roc_auc_score,
+    log_loss,
+    mean_squared_error,
+    mean_absolute_error,
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    r2_score,
 )
 
-CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
-REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
+CLASSIFICATION_METRICS = {
+    "auc",
+    "gauc",
+    "ks",
+    "logloss",
+    "accuracy",
+    "acc",
+    "precision",
+    "recall",
+    "f1",
+    "micro_f1",
+    "macro_f1",
+}
+REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
 TASK_DEFAULT_METRICS = {
-    'binary': ['auc', 'gauc', 'ks', 'logloss', 'accuracy', 'precision', 'recall', 'f1'],
-    'regression': ['mse', 'mae', 'rmse', 'r2', 'mape'],
-    'multilabel': ['auc', 'hamming_loss', 'subset_accuracy', 'micro_f1', 'macro_f1'],
-    'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10','cosine']+ [f'recall@{k}' for k in (5,10,20)] + [f'ndcg@{k}' for k in (5,10,20)] + [f'mrr@{k}' for k in (5,10,20)]
+    "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
+    "regression": ["mse", "mae", "rmse", "r2", "mape"],
+    "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
+    "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
+    + [f"recall@{k}" for k in (5, 10, 20)]
+    + [f"ndcg@{k}" for k in (5, 10, 20)]
+    + [f"mrr@{k}" for k in (5, 10, 20)],
 }
 
 
@@ -45,18 +68,21 @@ def check_user_id(*metric_sources: Any) -> bool:
     for name in metric_names:
         if name == "gauc":
             return True
-        if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
+        if name.startswith(
+            ("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")
+        ):
             return True
     return False
 
+
 def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Kolmogorov-Smirnov statistic."""
     sorted_indices = np.argsort(y_pred)[::-1]
     y_true_sorted = y_true[sorted_indices]
-
+
     n_pos = np.sum(y_true_sorted == 1)
     n_neg = np.sum(y_true_sorted == 0)
-
+
     if n_pos > 0 and n_neg > 0:
         cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
         cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
@@ -64,24 +90,34 @@ def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
         return float(ks_value)
     return 0.0
 
+
 def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Absolute Percentage Error."""
     mask = y_true != 0
     if np.any(mask):
-        return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
+        return float(
+            np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
+        )
     return 0.0
 
+
 def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Squared Log Error."""
     y_pred_pos = np.maximum(y_pred, 0)
     return float(mean_squared_error(np.log1p(y_true), np.log1p(y_pred_pos)))
 
-def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None) -> float:
+
+def compute_gauc(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None
+) -> float:
     if user_ids is None:
         # If no user_ids provided, fall back to regular AUC
         try:
             return float(roc_auc_score(y_true, y_pred))
-        except:
+        except Exception as e:
+            logging.warning(
+                f"[Metrics Warning: GAUC] Failed to compute AUC without user_ids: {e}"
+            )
             return 0.0
     # Group by user_id and calculate AUC for each user
     user_aucs = []
@@ -94,12 +130,10 @@ def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray |
         # Skip users with only one class (cannot compute AUC)
         if len(np.unique(user_y_true)) < 2:
             continue
-        try:
-            user_auc = roc_auc_score(user_y_true, user_y_pred)
-            user_aucs.append(user_auc)
-            user_weights.append(len(user_y_true))
-        except:
-            continue
+        user_auc = roc_auc_score(user_y_true, user_y_pred)
+        user_aucs.append(user_auc)
+        user_weights.append(len(user_y_true))
+
     if len(user_aucs) == 0:
         return 0.0
     # Weighted average
@@ -108,22 +142,30 @@ def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray |
     gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
     return gauc
 
+
 def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
     """Group sample indices by user_id. If user_ids is None, treat all as one group."""
     if user_ids is None:
         return [np.arange(n_samples)]
     user_ids = np.asarray(user_ids)
     if user_ids.shape[0] != n_samples:
-        logging.warning(f"[Metrics Warning: GAUC] user_ids length {user_ids.shape[0]} != number of samples {n_samples}, treating all samples as a single group for ranking metrics.")
+        logging.warning(
+            f"[Metrics Warning: GAUC] user_ids length {user_ids.shape[0]} != number of samples {n_samples}, treating all samples as a single group for ranking metrics."
+        )
         return [np.arange(n_samples)]
     unique_users = np.unique(user_ids)
     groups = [np.where(user_ids == u)[0] for u in unique_users]
     return groups
 
-def compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_precision_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute Precision@K."""
     if user_ids is None:
-        raise ValueError("[Metrics Error: Precision@K] user_ids must be provided for Precision@K computation.")
+        raise ValueError(
+            "[Metrics Error: Precision@K] user_ids must be provided for Precision@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -140,10 +182,15 @@ def compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.
         precisions.append(hits / float(k_user))
     return float(np.mean(precisions)) if precisions else 0.0
 
-def compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_recall_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute Recall@K."""
     if user_ids is None:
-        raise ValueError("[Metrics Error: Recall@K] user_ids must be provided for Recall@K computation.")
+        raise ValueError(
+            "[Metrics Error: Recall@K] user_ids must be provided for Recall@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -163,10 +210,15 @@ def compute_recall_at_k(y_true: np.ndarray, y_pred: np.nda
         recalls.append(hits / float(num_pos))
     return float(np.mean(recalls)) if recalls else 0.0
 
-def compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_hitrate_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute HitRate@K."""
     if user_ids is None:
-        raise ValueError("[Metrics Error: HitRate@K] user_ids must be provided for HitRate@K computation.")
+        raise ValueError(
+            "[Metrics Error: HitRate@K] user_ids must be provided for HitRate@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -185,10 +237,15 @@ def compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.nd
         hits_per_user.append(1.0 if hits > 0 else 0.0)
     return float(np.mean(hits_per_user)) if hits_per_user else 0.0
 
-def compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_mrr_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute MRR@K."""
     if user_ids is None:
-        raise ValueError("[Metrics Error: MRR@K] user_ids must be provided for MRR@K computation.")
+        raise ValueError(
+            "[Metrics Error: MRR@K] user_ids must be provided for MRR@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -212,6 +269,7 @@ def compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarra
         mrrs.append(rr)
     return float(np.mean(mrrs)) if mrrs else 0.0
 
+
 def compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     k_user = min(k, labels.size)
     if k_user == 0:
@@ -220,10 +278,15 @@ def compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     discounts = np.log2(np.arange(2, k_user + 2))
     return float(np.sum(gains / discounts))
 
-def compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+
+def compute_ndcg_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Compute NDCG@K."""
     if user_ids is None:
-        raise ValueError("[Metrics Error: NDCG@K] user_ids must be provided for NDCG@K computation.")
+        raise ValueError(
+            "[Metrics Error: NDCG@K] user_ids must be provided for NDCG@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -247,10 +310,14 @@ def compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarr
     return float(np.mean(ndcgs)) if ndcgs else 0.0
 
 
-def compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+def compute_map_at_k(
+    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
+) -> float:
     """Mean Average Precision@K."""
     if user_ids is None:
-        raise ValueError("[Metrics Error: MAP@K] user_ids must be provided for MAP@K computation.")
+        raise ValueError(
+            "[Metrics Error: MAP@K] user_ids must be provided for MAP@K computation."
+        )
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = group_indices_by_user(user_ids, n)
@@ -283,19 +350,21 @@ def compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     y_true = (y_true > 0).astype(int)
     pos_mask = y_true == 1
     neg_mask = y_true == 0
-
+
     if not np.any(pos_mask) or not np.any(neg_mask):
         return 0.0
-
+
     pos_mean = float(np.mean(y_pred[pos_mask]))
     neg_mean = float(np.mean(y_pred[neg_mask]))
     return pos_mean - neg_mean
 
 
 def configure_metrics(
-    task: str | list[str], # 'binary' or ['binary', 'regression']
-    metrics: list[str] | dict[str, list[str]] | None, # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
-    target_names: list[str] # ['target1', 'target2']
+    task: str | list[str],  # 'binary' or ['binary', 'regression']
+    metrics: (
+        list[str] | dict[str, list[str]] | None
+    ),  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
+    target_names: list[str],  # ['target1', 'target2']
 ) -> tuple[list[str], dict[str, list[str]] | None, str]:
     """Configure metrics based on task and user input."""
     primary_task = task[0] if isinstance(task, list) else task
@@ -307,7 +376,9 @@ def configure_metrics(
         task_specific_metrics = {}
         for task_name, task_metrics in metrics.items():
             if task_name not in target_names:
-                logging.warning(f"[Metrics Warning] Task {task_name} not found in targets {target_names}, skipping its metrics")
+                logging.warning(
+                    f"[Metrics Warning] Task {task_name} not found in targets {target_names}, skipping its metrics"
+                )
                 continue
             lowered = [m.lower() for m in task_metrics]
             task_specific_metrics[task_name] = lowered
@@ -341,114 +412,168 @@ def configure_metrics(
     best_metrics_mode = getbest_metric_mode(metrics_list[0], primary_task)
     return metrics_list, task_specific_metrics, best_metrics_mode
 
+
 def getbest_metric_mode(first_metric: str, primary_task: str) -> str:
     """Determine if metric should be maximized or minimized."""
     first_metric_lower = first_metric.lower()
     # Metrics that should be maximized
-    if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
-        return 'max'
+    if first_metric_lower in {
+        "auc",
+        "gauc",
+        "ks",
+        "accuracy",
+        "acc",
+        "precision",
+        "recall",
+        "f1",
+        "r2",
+        "micro_f1",
+        "macro_f1",
+    }:
+        return "max"
     # Ranking metrics that should be maximized (with @K suffix)
-    if (first_metric_lower.startswith('recall@') or
-        first_metric_lower.startswith('precision@') or
-        first_metric_lower.startswith('hitrate@') or
-        first_metric_lower.startswith('hr@') or
-        first_metric_lower.startswith('mrr@') or
-        first_metric_lower.startswith('ndcg@') or
-        first_metric_lower.startswith('map@')):
-        return 'max'
+    if (
+        first_metric_lower.startswith("recall@")
+        or first_metric_lower.startswith("precision@")
+        or first_metric_lower.startswith("hitrate@")
+        or first_metric_lower.startswith("hr@")
+        or first_metric_lower.startswith("mrr@")
+        or first_metric_lower.startswith("ndcg@")
+        or first_metric_lower.startswith("map@")
+    ):
+        return "max"
     # Cosine separation should be maximized
-    if first_metric_lower == 'cosine':
-        return 'max'
+    if first_metric_lower == "cosine":
+        return "max"
     # Metrics that should be minimized
-    if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
-        return 'min'
+    if first_metric_lower in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
+        return "min"
     # Default based on task type
-    if primary_task == 'regression':
-        return 'min'
-    return 'max'
+    if primary_task == "regression":
+        return "min"
+    return "max"
+
 
 def compute_single_metric(
     metric: str,
     y_true: np.ndarray,
     y_pred: np.ndarray,
     task_type: str,
-    user_ids: np.ndarray | None = None
+    user_ids: np.ndarray | None = None,
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
     try:
         metric_lower = metric.lower()
-        if metric_lower.startswith('recall@'):
-            k = int(metric_lower.split('@')[1])
-            return compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
-        if metric_lower.startswith('precision@'):
-            k = int(metric_lower.split('@')[1])
-            return compute_precision_at_k(y_true, y_pred, user_ids, k) # type: ignore
-        if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
-            k_str = metric_lower.split('@')[1]
+        if metric_lower.startswith("recall@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_recall_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("precision@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_precision_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+            k_str = metric_lower.split("@")[1]
             k = int(k_str)
-            return compute_hitrate_at_k(y_true, y_pred, user_ids, k) # type: ignore
-        if metric_lower.startswith('mrr@'):
-            k = int(metric_lower.split('@')[1])
-            return compute_mrr_at_k(y_true, y_pred, user_ids, k) # type: ignore
-        if metric_lower.startswith('ndcg@'):
-            k = int(metric_lower.split('@')[1])
-            return compute_ndcg_at_k(y_true, y_pred, user_ids, k) # type: ignore
-        if metric_lower.startswith('map@'):
-            k = int(metric_lower.split('@')[1])
-            return compute_map_at_k(y_true, y_pred, user_ids, k) # type: ignore
+            return compute_hitrate_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("mrr@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_mrr_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("ndcg@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_ndcg_at_k(y_true, y_pred, user_ids, k)  # type: ignore
+        if metric_lower.startswith("map@"):
+            k = int(metric_lower.split("@")[1])
+            return compute_map_at_k(y_true, y_pred, user_ids, k)  # type: ignore
         # cosine for matching task
-        if metric_lower == 'cosine':
+        if metric_lower == "cosine":
             return compute_cosine_separation(y_true, y_pred)
-        if metric == 'auc':
-            value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
-        elif metric == 'gauc':
+        if metric == "auc":
+            value = float(
+                roc_auc_score(
+                    y_true,
+                    y_pred,
+                    average="macro" if task_type == "multilabel" else None,
+                )
+            )
+        elif metric == "gauc":
            value = float(compute_gauc(y_true, y_pred, user_ids))
-        elif metric == 'ks':
+        elif metric == "ks":
            value = float(compute_ks(y_true, y_pred))
-        elif metric == 'logloss':
+        elif metric == "logloss":
            value = float(log_loss(y_true, y_pred))
-        elif metric in ('accuracy', 'acc'):
+        elif metric in ("accuracy", "acc"):
            value = float(accuracy_score(y_true, y_p_binary))
-        elif metric == 'precision':
-            value = float(precision_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
-        elif metric == 'recall':
-            value = float(recall_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
-        elif metric == 'f1':
-            value = float(f1_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
-        elif metric == 'micro_f1':
-            value = float(f1_score(y_true, y_p_binary, average='micro', zero_division=0))
-        elif metric == 'macro_f1':
-            value = float(f1_score(y_true, y_p_binary, average='macro', zero_division=0))
-        elif metric == 'mse':
+        elif metric == "precision":
+            value = float(
+                precision_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "recall":
+            value = float(
+                recall_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "f1":
+            value = float(
+                f1_score(
+                    y_true,
+                    y_p_binary,
+                    average="samples" if task_type == "multilabel" else "binary",
+                    zero_division=0,
+                )
+            )
+        elif metric == "micro_f1":
+            value = float(
+                f1_score(y_true, y_p_binary, average="micro", zero_division=0)
+            )
+        elif metric == "macro_f1":
+            value = float(
+                f1_score(y_true, y_p_binary, average="macro", zero_division=0)
+            )
+        elif metric == "mse":
            value = float(mean_squared_error(y_true, y_pred))
-        elif metric == 'mae':
+        elif metric == "mae":
            value = float(mean_absolute_error(y_true, y_pred))
-        elif metric == 'rmse':
+        elif metric == "rmse":
            value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
-        elif metric == 'r2':
+        elif metric == "r2":
            value = float(r2_score(y_true, y_pred))
-        elif metric == 'mape':
+        elif metric == "mape":
            value = float(compute_mape(y_true, y_pred))
-        elif metric == 'msle':
+        elif metric == "msle":
            value = float(compute_msle(y_true, y_pred))
         else:
-            logging.warning(f"[Metric Warning] Metric '{metric}' is not supported, returning 0.0")
+            logging.warning(
+                f"[Metric Warning] Metric '{metric}' is not supported, returning 0.0"
+            )
             value = 0.0
     except Exception as exception:
-        logging.warning(f"[Metric Warning] Failed to compute metric {metric}: {exception}")
+        logging.warning(
+            f"[Metric Warning] Failed to compute metric {metric}: {exception}"
+        )
         value = 0.0
     return value
 
+
 def evaluate_metrics(
     y_true: np.ndarray | None,
     y_pred: np.ndarray | None,
-    metrics: list[str], # example: ['auc', 'logloss']
-    task: str | list[str], # example: 'binary' or ['binary', 'regression']
-    target_names: list[str], # example: ['target1', 'target2']
-    task_specific_metrics: dict[str, list[str]] | None = None, # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
-    user_ids: np.ndarray | None = None # example: User IDs for GAUC computation
-) -> dict: # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
+    metrics: list[str],  # example: ['auc', 'logloss']
+    task: str | list[str],  # example: 'binary' or ['binary', 'regression']
+    target_names: list[str],  # example: ['target1', 'target2']
+    task_specific_metrics: (
+        dict[str, list[str]] | None
+    ) = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+    user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
+) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
     result = {}
     if y_true is None or y_pred is None:
@@ -460,7 +585,9 @@ def evaluate_metrics(
     if nums_task == 1:
         for metric in metrics:
             metric_lower = metric.lower()
-            value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
+            value = compute_single_metric(
+                metric_lower, y_true, y_pred, primary_task, user_ids
+            )
             result[metric_lower] = value
     # Multi-task evaluation
     else:
@@ -471,7 +598,9 @@ def evaluate_metrics(
                 should_compute = True
                 if task_specific_metrics is not None and task_idx < len(target_names):
                     task_name = target_names[task_idx]
-                    should_compute = metric_lower in task_specific_metrics.get(task_name, [])
+                    should_compute = metric_lower in task_specific_metrics.get(
+                        task_name, []
+                    )
                 else:
                     # Get task type for specific index
                     if isinstance(task, list) and task_idx < len(task):
@@ -479,24 +608,44 @@ def evaluate_metrics(
                     elif isinstance(task, str):
                         task_type = task
                     else:
-                        task_type = 'binary'
-                    if task_type in ['binary', 'multilabel']:
-                        should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
-                    elif task_type == 'regression':
-                        should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
+                        task_type = "binary"
+                    if task_type in ["binary", "multilabel"]:
+                        should_compute = metric_lower in {
+                            "auc",
+                            "ks",
+                            "logloss",
+                            "accuracy",
+                            "acc",
+                            "precision",
+                            "recall",
+                            "f1",
+                            "micro_f1",
+                            "macro_f1",
+                        }
+                    elif task_type == "regression":
+                        should_compute = metric_lower in {
+                            "mse",
+                            "mae",
+                            "rmse",
+                            "r2",
+                            "mape",
+                            "msle",
+                        }
                 if not should_compute:
-                    continue
-                target_name = target_names[task_idx]
+                    continue
+                target_name = target_names[task_idx]
                 # Get task type for specific index
                 if isinstance(task, list) and task_idx < len(task):
                     task_type = task[task_idx]
                 elif isinstance(task, str):
                     task_type = task
                 else:
-                    task_type = 'binary'
+                    task_type = "binary"
                 y_true_task = y_true[:, task_idx]
-                y_pred_task = y_pred[:, task_idx]
+                y_pred_task = y_pred[:, task_idx]
                 # Compute metric
-                value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
-                result[f'{metric_lower}_{target_name}'] = value
+                value = compute_single_metric(
+                    metric_lower, y_true_task, y_pred_task, task_type, user_ids
+                )
+                result[f"{metric_lower}_{target_name}"] = value
     return result
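
For orientation, here is a minimal sketch of how the reworked metrics module could be exercised after this change. It uses only names visible in this diff (nextrec.basic.metrics.evaluate_metrics and its keyword parameters); the toy arrays, target name, and chosen metric list are illustrative assumptions, not taken from the package or its tests.

# Hedged usage sketch; the data below is invented for illustration only.
import numpy as np

from nextrec.basic.metrics import evaluate_metrics

y_true = np.array([1, 0, 1, 1, 0, 0])               # binary labels
y_pred = np.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.1])   # model scores
user_ids = np.array([1, 1, 1, 2, 2, 2])             # needed for gauc and @K metrics

result = evaluate_metrics(
    y_true=y_true,
    y_pred=y_pred,
    metrics=["auc", "gauc", "logloss", "recall@2"],  # lower-cased internally
    task="binary",
    target_names=["click"],
    user_ids=user_ids,
)
print(result)  # e.g. {"auc": ..., "gauc": ..., "logloss": ..., "recall@2": ...}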