nextrec 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.3.dist-info}/METADATA +7 -41
  45. nextrec-0.1.3.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.3.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/metrics.py CHANGED
@@ -5,33 +5,55 @@ Date: create on 27/10/2025
  Author:
  Yang Zhou,zyaztec@gmail.com
  """
+
  import logging
  import numpy as np
  from sklearn.metrics import (
- roc_auc_score, log_loss, mean_squared_error, mean_absolute_error,
- accuracy_score, precision_score, recall_score, f1_score, r2_score,
+ roc_auc_score,
+ log_loss,
+ mean_squared_error,
+ mean_absolute_error,
+ accuracy_score,
+ precision_score,
+ recall_score,
+ f1_score,
+ r2_score,
  )


- CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
- REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
+ CLASSIFICATION_METRICS = {
+ "auc",
+ "gauc",
+ "ks",
+ "logloss",
+ "accuracy",
+ "acc",
+ "precision",
+ "recall",
+ "f1",
+ "micro_f1",
+ "macro_f1",
+ }
+ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
  TASK_DEFAULT_METRICS = {
- 'binary': ['auc', 'gauc', 'ks', 'logloss', 'accuracy', 'precision', 'recall', 'f1'],
- 'regression': ['mse', 'mae', 'rmse', 'r2', 'mape'],
- 'multilabel': ['auc', 'hamming_loss', 'subset_accuracy', 'micro_f1', 'macro_f1'],
- 'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10','cosine']+ [f'recall@{k}' for k in (5,10,20)] + [f'ndcg@{k}' for k in (5,10,20)] + [f'mrr@{k}' for k in (5,10,20)]
+ "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
+ "regression": ["mse", "mae", "rmse", "r2", "mape"],
+ "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
+ "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
+ + [f"recall@{k}" for k in (5, 10, 20)]
+ + [f"ndcg@{k}" for k in (5, 10, 20)]
+ + [f"mrr@{k}" for k in (5, 10, 20)],
  }


-
  def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Compute Kolmogorov-Smirnov statistic."""
  sorted_indices = np.argsort(y_pred)[::-1]
  y_true_sorted = y_true[sorted_indices]
-
+
  n_pos = np.sum(y_true_sorted == 1)
  n_neg = np.sum(y_true_sorted == 0)
-
+
  if n_pos > 0 and n_neg > 0:
  cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
  cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
@@ -44,7 +66,9 @@ def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Compute Mean Absolute Percentage Error."""
  mask = y_true != 0
  if np.any(mask):
- return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
+ return float(
+ np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
+ )
  return 0.0


@@ -55,9 +79,7 @@ def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:


  def compute_gauc(
- y_true: np.ndarray,
- y_pred: np.ndarray,
- user_ids: np.ndarray | None = None
+ y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None
  ) -> float:
  if user_ids is None:
  # If no user_ids provided, fall back to regular AUC
@@ -65,37 +87,37 @@ def compute_gauc(
  return float(roc_auc_score(y_true, y_pred))
  except:
  return 0.0
-
+
  # Group by user_id and calculate AUC for each user
  user_aucs = []
  user_weights = []
-
+
  unique_users = np.unique(user_ids)
-
+
  for user_id in unique_users:
  mask = user_ids == user_id
  user_y_true = y_true[mask]
  user_y_pred = y_pred[mask]
-
+
  # Skip users with only one class (cannot compute AUC)
  if len(np.unique(user_y_true)) < 2:
  continue
-
+
  try:
  user_auc = roc_auc_score(user_y_true, user_y_pred)
  user_aucs.append(user_auc)
  user_weights.append(len(user_y_true))
  except:
  continue
-
+
  if len(user_aucs) == 0:
  return 0.0
-
+
  # Weighted average
  user_aucs = np.array(user_aucs)
  user_weights = np.array(user_weights)
  gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
-
+
  return gauc


@@ -103,7 +125,7 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
  """Group sample indices by user_id. If user_ids is None, treat all as one group."""
  if user_ids is None:
  return [np.arange(n_samples)]
-
+
  user_ids = np.asarray(user_ids)
  if user_ids.shape[0] != n_samples:
  logging.warning(
@@ -113,17 +135,19 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
  n_samples,
  )
  return [np.arange(n_samples)]
-
+
  unique_users = np.unique(user_ids)
  groups = [np.where(user_ids == u)[0] for u in unique_users]
  return groups


- def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_precision_at_k(
+ y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+ ) -> float:
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
+
  precisions = []
  for idx in groups:
  if idx.size == 0:
@@ -135,16 +159,18 @@ def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np
  topk = order[:k_user]
  hits = labels[topk].sum()
  precisions.append(hits / float(k_user))
-
+
  return float(np.mean(precisions)) if precisions else 0.0


- def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_recall_at_k(
+ y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+ ) -> float:
  """Compute Recall@K."""
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
+
  recalls = []
  for idx in groups:
  if idx.size == 0:
@@ -159,16 +185,18 @@ def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.nd
  topk = order[:k_user]
  hits = labels[topk].sum()
  recalls.append(hits / float(num_pos))
-
+
  return float(np.mean(recalls)) if recalls else 0.0


- def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_hitrate_at_k(
+ y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+ ) -> float:
  """Compute HitRate@K."""
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
+
  hits_per_user = []
  for idx in groups:
  if idx.size == 0:
@@ -182,16 +210,18 @@ def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.n
  topk = order[:k_user]
  hits = labels[topk].sum()
  hits_per_user.append(1.0 if hits > 0 else 0.0)
-
+
  return float(np.mean(hits_per_user)) if hits_per_user else 0.0


- def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_mrr_at_k(
+ y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+ ) -> float:
  """Compute MRR@K."""
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
+
  mrrs = []
  for idx in groups:
  if idx.size == 0:
@@ -204,14 +234,14 @@ def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
  k_user = min(k, idx.size)
  topk = order[:k_user]
  ranked_labels = labels[order]
-
+
  rr = 0.0
  for rank, lab in enumerate(ranked_labels[:k_user], start=1):
  if lab > 0:
  rr = 1.0 / rank
  break
  mrrs.append(rr)
-
+
  return float(np.mean(mrrs)) if mrrs else 0.0


@@ -224,12 +254,14 @@ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
  return float(np.sum(gains / discounts))


- def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_ndcg_at_k(
+ y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+ ) -> float:
  """Compute NDCG@K."""
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
+
  ndcgs = []
  for idx in groups:
  if idx.size == 0:
@@ -238,27 +270,29 @@ def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndar
  if labels.sum() == 0:
  continue
  scores = y_pred[idx]
-
+
  order = np.argsort(scores)[::-1]
  ranked_labels = labels[order]
  dcg = _compute_dcg_at_k(ranked_labels, k)
-
+
  # ideal DCG
  ideal_labels = np.sort(labels)[::-1]
  idcg = _compute_dcg_at_k(ideal_labels, k)
  if idcg == 0.0:
  continue
  ndcgs.append(dcg / idcg)
-
+
  return float(np.mean(ndcgs)) if ndcgs else 0.0


- def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_map_at_k(
+ y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
+ ) -> float:
  """Mean Average Precision@K."""
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
+
  aps = []
  for idx in groups:
  if idx.size == 0:
@@ -267,23 +301,23 @@ def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
  num_pos = labels.sum()
  if num_pos == 0:
  continue
-
+
  scores = y_pred[idx]
  order = np.argsort(scores)[::-1]
  k_user = min(k, idx.size)
-
+
  hits = 0
  sum_precisions = 0.0
  for rank, i in enumerate(order[:k_user], start=1):
  if labels[i] > 0:
  hits += 1
  sum_precisions += hits / float(rank)
-
+
  if hits == 0:
  aps.append(0.0)
  else:
  aps.append(sum_precisions / float(num_pos))
-
+
  return float(np.mean(aps)) if aps else 0.0


@@ -292,24 +326,26 @@ def _compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  y_true = (y_true > 0).astype(int)
  pos_mask = y_true == 1
  neg_mask = y_true == 0
-
+
  if not np.any(pos_mask) or not np.any(neg_mask):
  return 0.0
-
+
  pos_mean = float(np.mean(y_pred[pos_mask]))
  neg_mean = float(np.mean(y_pred[neg_mask]))
  return pos_mean - neg_mean


  def configure_metrics(
- task: str | list[str], # 'binary' or ['binary', 'regression']
- metrics: list[str] | dict[str, list[str]] | None, # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
- target_names: list[str] # ['target1', 'target2']
+ task: str | list[str],  # 'binary' or ['binary', 'regression']
+ metrics: (
+ list[str] | dict[str, list[str]] | None
+ ),  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
+ target_names: list[str],  # ['target1', 'target2']
  ) -> tuple[list[str], dict[str, list[str]] | None, str]:
  """Configure metrics based on task and user input."""
  primary_task = task[0] if isinstance(task, list) else task
  nums_task = len(task) if isinstance(task, list) else 1
-
+
  metrics_list: list[str] = []
  task_specific_metrics: dict[str, list[str]] | None = None

@@ -351,48 +387,62 @@ def configure_metrics(
  if primary_task not in TASK_DEFAULT_METRICS:
  raise ValueError(f"Unsupported task type: {primary_task}")
  metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
  if not metrics_list:
  # Inline get_default_metrics_for_task logic
  if primary_task not in TASK_DEFAULT_METRICS:
  raise ValueError(f"Unsupported task type: {primary_task}")
  metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
  best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
-
+
  return metrics_list, task_specific_metrics, best_metrics_mode


  def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
  """Determine if metric should be maximized or minimized."""
  first_metric_lower = first_metric.lower()
-
+
  # Metrics that should be maximized
- if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
- return 'max'
-
+ if first_metric_lower in {
+ "auc",
+ "gauc",
+ "ks",
+ "accuracy",
+ "acc",
+ "precision",
+ "recall",
+ "f1",
+ "r2",
+ "micro_f1",
+ "macro_f1",
+ }:
+ return "max"
+
  # Ranking metrics that should be maximized (with @K suffix)
- if (first_metric_lower.startswith('recall@') or
- first_metric_lower.startswith('precision@') or
- first_metric_lower.startswith('hitrate@') or
- first_metric_lower.startswith('hr@') or
- first_metric_lower.startswith('mrr@') or
- first_metric_lower.startswith('ndcg@') or
- first_metric_lower.startswith('map@')):
- return 'max'
-
+ if (
+ first_metric_lower.startswith("recall@")
+ or first_metric_lower.startswith("precision@")
+ or first_metric_lower.startswith("hitrate@")
+ or first_metric_lower.startswith("hr@")
+ or first_metric_lower.startswith("mrr@")
+ or first_metric_lower.startswith("ndcg@")
+ or first_metric_lower.startswith("map@")
+ ):
+ return "max"
+
  # Cosine separation should be maximized
- if first_metric_lower == 'cosine':
- return 'max'
-
+ if first_metric_lower == "cosine":
+ return "max"
+
  # Metrics that should be minimized
- if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
- return 'min'
-
+ if first_metric_lower in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
+ return "min"
+
  # Default based on task type
- if primary_task == 'regression':
- return 'min'
- return 'max'
+ if primary_task == "regression":
+ return "min"
+ return "max"


  def compute_single_metric(
@@ -400,80 +450,111 @@ def compute_single_metric(
  y_true: np.ndarray,
  y_pred: np.ndarray,
  task_type: str,
- user_ids: np.ndarray | None = None
+ user_ids: np.ndarray | None = None,
  ) -> float:
  """Compute a single metric given true and predicted values."""
  y_p_binary = (y_pred > 0.5).astype(int)
-
+
  try:
  metric_lower = metric.lower()
-
+
  # recall@K
- if metric_lower.startswith('recall@'):
- k = int(metric_lower.split('@')[1])
+ if metric_lower.startswith("recall@"):
+ k = int(metric_lower.split("@")[1])
  return _compute_recall_at_k(y_true, y_pred, user_ids, k)

  # precision@K
- if metric_lower.startswith('precision@'):
- k = int(metric_lower.split('@')[1])
+ if metric_lower.startswith("precision@"):
+ k = int(metric_lower.split("@")[1])
  return _compute_precision_at_k(y_true, y_pred, user_ids, k)

  # hitrate@K / hr@K
- if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
- k_str = metric_lower.split('@')[1]
+ if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
+ k_str = metric_lower.split("@")[1]
  k = int(k_str)
  return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)

  # mrr@K
- if metric_lower.startswith('mrr@'):
- k = int(metric_lower.split('@')[1])
+ if metric_lower.startswith("mrr@"):
+ k = int(metric_lower.split("@")[1])
  return _compute_mrr_at_k(y_true, y_pred, user_ids, k)

  # ndcg@K
- if metric_lower.startswith('ndcg@'):
- k = int(metric_lower.split('@')[1])
+ if metric_lower.startswith("ndcg@"):
+ k = int(metric_lower.split("@")[1])
  return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)

  # map@K
- if metric_lower.startswith('map@'):
- k = int(metric_lower.split('@')[1])
+ if metric_lower.startswith("map@"):
+ k = int(metric_lower.split("@")[1])
  return _compute_map_at_k(y_true, y_pred, user_ids, k)

  # cosine for matching task
- if metric_lower == 'cosine':
+ if metric_lower == "cosine":
  return _compute_cosine_separation(y_true, y_pred)
-
- if metric == 'auc':
- value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
- elif metric == 'gauc':
+
+ if metric == "auc":
+ value = float(
+ roc_auc_score(
+ y_true,
+ y_pred,
+ average="macro" if task_type == "multilabel" else None,
+ )
+ )
+ elif metric == "gauc":
  value = float(compute_gauc(y_true, y_pred, user_ids))
- elif metric == 'ks':
+ elif metric == "ks":
  value = float(compute_ks(y_true, y_pred))
- elif metric == 'logloss':
+ elif metric == "logloss":
  value = float(log_loss(y_true, y_pred))
- elif metric in ('accuracy', 'acc'):
+ elif metric in ("accuracy", "acc"):
  value = float(accuracy_score(y_true, y_p_binary))
- elif metric == 'precision':
- value = float(precision_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
- elif metric == 'recall':
- value = float(recall_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
- elif metric == 'f1':
- value = float(f1_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
- elif metric == 'micro_f1':
- value = float(f1_score(y_true, y_p_binary, average='micro', zero_division=0))
- elif metric == 'macro_f1':
- value = float(f1_score(y_true, y_p_binary, average='macro', zero_division=0))
- elif metric == 'mse':
+ elif metric == "precision":
+ value = float(
+ precision_score(
+ y_true,
+ y_p_binary,
+ average="samples" if task_type == "multilabel" else "binary",
+ zero_division=0,
+ )
+ )
+ elif metric == "recall":
+ value = float(
+ recall_score(
+ y_true,
+ y_p_binary,
+ average="samples" if task_type == "multilabel" else "binary",
+ zero_division=0,
+ )
+ )
+ elif metric == "f1":
+ value = float(
+ f1_score(
+ y_true,
+ y_p_binary,
+ average="samples" if task_type == "multilabel" else "binary",
+ zero_division=0,
+ )
+ )
+ elif metric == "micro_f1":
+ value = float(
+ f1_score(y_true, y_p_binary, average="micro", zero_division=0)
+ )
+ elif metric == "macro_f1":
+ value = float(
+ f1_score(y_true, y_p_binary, average="macro", zero_division=0)
+ )
+ elif metric == "mse":
  value = float(mean_squared_error(y_true, y_pred))
- elif metric == 'mae':
+ elif metric == "mae":
  value = float(mean_absolute_error(y_true, y_pred))
- elif metric == 'rmse':
+ elif metric == "rmse":
  value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
- elif metric == 'r2':
+ elif metric == "r2":
  value = float(r2_score(y_true, y_pred))
- elif metric == 'mape':
+ elif metric == "mape":
  value = float(compute_mape(y_true, y_pred))
- elif metric == 'msle':
+ elif metric == "msle":
  value = float(compute_msle(y_true, y_pred))
  else:
  logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
@@ -481,35 +562,40 @@ def compute_single_metric(
  except Exception as exception:
  logging.warning(f"Failed to compute metric {metric}: {exception}")
  value = 0.0
-
+
  return value

+
  def evaluate_metrics(
  y_true: np.ndarray | None,
  y_pred: np.ndarray | None,
- metrics: list[str], # example: ['auc', 'logloss']
- task: str | list[str], # example: 'binary' or ['binary', 'regression']
- target_names: list[str], # example: ['target1', 'target2']
- task_specific_metrics: dict[str, list[str]] | None = None, # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
- user_ids: np.ndarray | None = None # example: User IDs for GAUC computation
- ) -> dict: # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
+ metrics: list[str],  # example: ['auc', 'logloss']
+ task: str | list[str],  # example: 'binary' or ['binary', 'regression']
+ target_names: list[str],  # example: ['target1', 'target2']
+ task_specific_metrics: (
+ dict[str, list[str]] | None
+ ) = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+ user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
+ ) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
  """Evaluate specified metrics for given true and predicted values."""
  result = {}
-
+
  if y_true is None or y_pred is None:
  return result
-
+
  # Main evaluation logic
  primary_task = task[0] if isinstance(task, list) else task
  nums_task = len(task) if isinstance(task, list) else 1
-
+
  # Single task evaluation
  if nums_task == 1:
  for metric in metrics:
  metric_lower = metric.lower()
- value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
+ value = compute_single_metric(
+ metric_lower, y_true, y_pred, primary_task, user_ids
+ )
  result[metric_lower] = value
-
+
  # Multi-task evaluation
  else:
  for metric in metrics:
@@ -519,7 +605,9 @@ def evaluate_metrics(
  should_compute = True
  if task_specific_metrics is not None and task_idx < len(target_names):
  task_name = target_names[task_idx]
- should_compute = metric_lower in task_specific_metrics.get(task_name, [])
+ should_compute = metric_lower in task_specific_metrics.get(
+ task_name, []
+ )
  else:
  # Get task type for specific index
  if isinstance(task, list) and task_idx < len(task):
@@ -527,31 +615,51 @@ def evaluate_metrics(
  elif isinstance(task, str):
  task_type = task
  else:
- task_type = 'binary'
-
- if task_type in ['binary', 'multilabel']:
- should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
- elif task_type == 'regression':
- should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
-
+ task_type = "binary"
+
+ if task_type in ["binary", "multilabel"]:
+ should_compute = metric_lower in {
+ "auc",
+ "ks",
+ "logloss",
+ "accuracy",
+ "acc",
+ "precision",
+ "recall",
+ "f1",
+ "micro_f1",
+ "macro_f1",
+ }
+ elif task_type == "regression":
+ should_compute = metric_lower in {
+ "mse",
+ "mae",
+ "rmse",
+ "r2",
+ "mape",
+ "msle",
+ }
+
  if not should_compute:
  continue
-
+
  target_name = target_names[task_idx]
-
+
  # Get task type for specific index
  if isinstance(task, list) and task_idx < len(task):
  task_type = task[task_idx]
  elif isinstance(task, str):
  task_type = task
  else:
- task_type = 'binary'
-
+ task_type = "binary"
+
  y_true_task = y_true[:, task_idx]
  y_pred_task = y_pred[:, task_idx]
-
+
  # Compute metric
- value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
- result[f'{metric_lower}_{target_name}'] = value
-
+ value = compute_single_metric(
+ metric_lower, y_true_task, y_pred_task, task_type, user_ids
+ )
+ result[f"{metric_lower}_{target_name}"] = value
+
  return result
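
For orientation, here is a minimal usage sketch of the metrics API touched by this diff (configure_metrics and evaluate_metrics in nextrec/basic/metrics.py). The sample arrays, metric choices, and the "click" target name are illustrative assumptions and not part of the package; only the function names and signatures come from the diff above.

```python
# Hypothetical sketch: exercising the metrics API as it appears in this diff.
# The data below is made up; function names and signatures follow
# nextrec/basic/metrics.py shown above.
import numpy as np

from nextrec.basic.metrics import configure_metrics, evaluate_metrics

y_true = np.array([1, 0, 1, 0, 1, 0])
y_pred = np.array([0.9, 0.2, 0.7, 0.4, 0.6, 0.1])
user_ids = np.array([1, 1, 1, 2, 2, 2])  # grouping key used by gauc and @K metrics

# Resolve the metric list, any per-target overrides, and whether the first
# metric is max- or min-better ("max" / "min").
metrics_list, task_specific, mode = configure_metrics(
    task="binary",
    metrics=["auc", "gauc", "logloss"],
    target_names=["click"],
)

# Single-task evaluation returns a dict keyed by lower-cased metric name,
# e.g. {"auc": ..., "gauc": ..., "logloss": ...}.
results = evaluate_metrics(
    y_true,
    y_pred,
    metrics=metrics_list,
    task="binary",
    target_names=["click"],
    task_specific_metrics=task_specific,
    user_ids=user_ids,
)
print(results, mode)
```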