nextrec 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/METADATA +16 -7
  45. nextrec-0.1.7.dist-info/RECORD +51 -0
  46. nextrec-0.1.4.dist-info/RECORD +0 -51
  47. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.4.dist-info → nextrec-0.1.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/metrics.py CHANGED
@@ -5,55 +5,33 @@ Date: create on 27/10/2025
 Author:
     Yang Zhou,zyaztec@gmail.com
 """
-
 import logging
 import numpy as np
 from sklearn.metrics import (
-    roc_auc_score,
-    log_loss,
-    mean_squared_error,
-    mean_absolute_error,
-    accuracy_score,
-    precision_score,
-    recall_score,
-    f1_score,
-    r2_score,
+    roc_auc_score, log_loss, mean_squared_error, mean_absolute_error,
+    accuracy_score, precision_score, recall_score, f1_score, r2_score,
 )


-CLASSIFICATION_METRICS = {
-    "auc",
-    "gauc",
-    "ks",
-    "logloss",
-    "accuracy",
-    "acc",
-    "precision",
-    "recall",
-    "f1",
-    "micro_f1",
-    "macro_f1",
-}
-REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
+CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
+REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
 TASK_DEFAULT_METRICS = {
-    "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
-    "regression": ["mse", "mae", "rmse", "r2", "mape"],
-    "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
-    "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
-    + [f"recall@{k}" for k in (5, 10, 20)]
-    + [f"ndcg@{k}" for k in (5, 10, 20)]
-    + [f"mrr@{k}" for k in (5, 10, 20)],
+    'binary': ['auc', 'gauc', 'ks', 'logloss', 'accuracy', 'precision', 'recall', 'f1'],
+    'regression': ['mse', 'mae', 'rmse', 'r2', 'mape'],
+    'multilabel': ['auc', 'hamming_loss', 'subset_accuracy', 'micro_f1', 'macro_f1'],
+    'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10','cosine']+ [f'recall@{k}' for k in (5,10,20)] + [f'ndcg@{k}' for k in (5,10,20)] + [f'mrr@{k}' for k in (5,10,20)]
 }


+
 def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Kolmogorov-Smirnov statistic."""
     sorted_indices = np.argsort(y_pred)[::-1]
     y_true_sorted = y_true[sorted_indices]
-
+
     n_pos = np.sum(y_true_sorted == 1)
     n_neg = np.sum(y_true_sorted == 0)
-
+
     if n_pos > 0 and n_neg > 0:
         cum_pos_rate = np.cumsum(y_true_sorted == 1) / n_pos
         cum_neg_rate = np.cumsum(y_true_sorted == 0) / n_neg
@@ -66,9 +44,7 @@ def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     """Compute Mean Absolute Percentage Error."""
     mask = y_true != 0
     if np.any(mask):
-        return float(
-            np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
-        )
+        return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
     return 0.0


@@ -79,7 +55,9 @@ def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:


 def compute_gauc(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    user_ids: np.ndarray | None = None
 ) -> float:
     if user_ids is None:
         # If no user_ids provided, fall back to regular AUC
@@ -87,37 +65,37 @@ def compute_gauc(
             return float(roc_auc_score(y_true, y_pred))
         except:
             return 0.0
-
+
     # Group by user_id and calculate AUC for each user
     user_aucs = []
     user_weights = []
-
+
     unique_users = np.unique(user_ids)
-
+
     for user_id in unique_users:
         mask = user_ids == user_id
         user_y_true = y_true[mask]
         user_y_pred = y_pred[mask]
-
+
         # Skip users with only one class (cannot compute AUC)
         if len(np.unique(user_y_true)) < 2:
             continue
-
+
         try:
             user_auc = roc_auc_score(user_y_true, user_y_pred)
             user_aucs.append(user_auc)
             user_weights.append(len(user_y_true))
         except:
             continue
-
+
     if len(user_aucs) == 0:
         return 0.0
-
+
     # Weighted average
     user_aucs = np.array(user_aucs)
     user_weights = np.array(user_weights)
     gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
-
+
     return gauc


@@ -125,7 +103,7 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
     """Group sample indices by user_id. If user_ids is None, treat all as one group."""
     if user_ids is None:
         return [np.arange(n_samples)]
-
+
     user_ids = np.asarray(user_ids)
     if user_ids.shape[0] != n_samples:
         logging.warning(
@@ -135,19 +113,17 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
             n_samples,
         )
         return [np.arange(n_samples)]
-
+
     unique_users = np.unique(user_ids)
     groups = [np.where(user_ids == u)[0] for u in unique_users]
     return groups


-def _compute_precision_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     precisions = []
     for idx in groups:
         if idx.size == 0:
@@ -159,18 +135,16 @@ def _compute_precision_at_k(
         topk = order[:k_user]
         hits = labels[topk].sum()
         precisions.append(hits / float(k_user))
-
+
     return float(np.mean(precisions)) if precisions else 0.0


-def _compute_recall_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute Recall@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     recalls = []
     for idx in groups:
         if idx.size == 0:
@@ -185,18 +159,16 @@ def _compute_recall_at_k(
         topk = order[:k_user]
         hits = labels[topk].sum()
         recalls.append(hits / float(num_pos))
-
+
     return float(np.mean(recalls)) if recalls else 0.0


-def _compute_hitrate_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute HitRate@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     hits_per_user = []
     for idx in groups:
         if idx.size == 0:
@@ -210,18 +182,16 @@ def _compute_hitrate_at_k(
         topk = order[:k_user]
         hits = labels[topk].sum()
         hits_per_user.append(1.0 if hits > 0 else 0.0)
-
+
     return float(np.mean(hits_per_user)) if hits_per_user else 0.0


-def _compute_mrr_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute MRR@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     mrrs = []
     for idx in groups:
         if idx.size == 0:
@@ -234,14 +204,14 @@ def _compute_mrr_at_k(
         k_user = min(k, idx.size)
         topk = order[:k_user]
         ranked_labels = labels[order]
-
+
         rr = 0.0
         for rank, lab in enumerate(ranked_labels[:k_user], start=1):
             if lab > 0:
                 rr = 1.0 / rank
                 break
         mrrs.append(rr)
-
+
     return float(np.mean(mrrs)) if mrrs else 0.0


@@ -254,14 +224,12 @@ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
     return float(np.sum(gains / discounts))


-def _compute_ndcg_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Compute NDCG@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     ndcgs = []
     for idx in groups:
         if idx.size == 0:
@@ -270,29 +238,27 @@ def _compute_ndcg_at_k(
         if labels.sum() == 0:
             continue
         scores = y_pred[idx]
-
+
         order = np.argsort(scores)[::-1]
         ranked_labels = labels[order]
         dcg = _compute_dcg_at_k(ranked_labels, k)
-
+
         # ideal DCG
         ideal_labels = np.sort(labels)[::-1]
         idcg = _compute_dcg_at_k(ideal_labels, k)
         if idcg == 0.0:
             continue
         ndcgs.append(dcg / idcg)
-
+
     return float(np.mean(ndcgs)) if ndcgs else 0.0


-def _compute_map_at_k(
-    y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int
-) -> float:
+def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
     """Mean Average Precision@K."""
     y_true = (y_true > 0).astype(int)
     n = len(y_true)
     groups = _group_indices_by_user(user_ids, n)
-
+
     aps = []
     for idx in groups:
         if idx.size == 0:
@@ -301,23 +267,23 @@ def _compute_map_at_k(
         num_pos = labels.sum()
         if num_pos == 0:
             continue
-
+
         scores = y_pred[idx]
         order = np.argsort(scores)[::-1]
         k_user = min(k, idx.size)
-
+
         hits = 0
         sum_precisions = 0.0
         for rank, i in enumerate(order[:k_user], start=1):
             if labels[i] > 0:
                 hits += 1
                 sum_precisions += hits / float(rank)
-
+
         if hits == 0:
             aps.append(0.0)
         else:
             aps.append(sum_precisions / float(num_pos))
-
+
     return float(np.mean(aps)) if aps else 0.0


@@ -326,26 +292,24 @@ def _compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
     y_true = (y_true > 0).astype(int)
     pos_mask = y_true == 1
     neg_mask = y_true == 0
-
+
     if not np.any(pos_mask) or not np.any(neg_mask):
         return 0.0
-
+
     pos_mean = float(np.mean(y_pred[pos_mask]))
     neg_mean = float(np.mean(y_pred[neg_mask]))
     return pos_mean - neg_mean


 def configure_metrics(
-    task: str | list[str],  # 'binary' or ['binary', 'regression']
-    metrics: (
-        list[str] | dict[str, list[str]] | None
-    ),  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
-    target_names: list[str],  # ['target1', 'target2']
+    task: str | list[str],  # 'binary' or ['binary', 'regression']
+    metrics: list[str] | dict[str, list[str]] | None,  # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
+    target_names: list[str]  # ['target1', 'target2']
 ) -> tuple[list[str], dict[str, list[str]] | None, str]:
     """Configure metrics based on task and user input."""
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
+
     metrics_list: list[str] = []
     task_specific_metrics: dict[str, list[str]] | None = None

@@ -387,62 +351,48 @@ def configure_metrics(
         if primary_task not in TASK_DEFAULT_METRICS:
             raise ValueError(f"Unsupported task type: {primary_task}")
         metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
     if not metrics_list:
         # Inline get_default_metrics_for_task logic
         if primary_task not in TASK_DEFAULT_METRICS:
             raise ValueError(f"Unsupported task type: {primary_task}")
         metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
+
     best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
-
+
     return metrics_list, task_specific_metrics, best_metrics_mode


 def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
     """Determine if metric should be maximized or minimized."""
     first_metric_lower = first_metric.lower()
-
+
     # Metrics that should be maximized
-    if first_metric_lower in {
-        "auc",
-        "gauc",
-        "ks",
-        "accuracy",
-        "acc",
-        "precision",
-        "recall",
-        "f1",
-        "r2",
-        "micro_f1",
-        "macro_f1",
-    }:
-        return "max"
-
+    if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
+        return 'max'
+
     # Ranking metrics that should be maximized (with @K suffix)
-    if (
-        first_metric_lower.startswith("recall@")
-        or first_metric_lower.startswith("precision@")
-        or first_metric_lower.startswith("hitrate@")
-        or first_metric_lower.startswith("hr@")
-        or first_metric_lower.startswith("mrr@")
-        or first_metric_lower.startswith("ndcg@")
-        or first_metric_lower.startswith("map@")
-    ):
-        return "max"
-
+    if (first_metric_lower.startswith('recall@') or
+        first_metric_lower.startswith('precision@') or
+        first_metric_lower.startswith('hitrate@') or
+        first_metric_lower.startswith('hr@') or
+        first_metric_lower.startswith('mrr@') or
+        first_metric_lower.startswith('ndcg@') or
+        first_metric_lower.startswith('map@')):
+        return 'max'
+
     # Cosine separation should be maximized
-    if first_metric_lower == "cosine":
-        return "max"
-
+    if first_metric_lower == 'cosine':
+        return 'max'
+
     # Metrics that should be minimized
-    if first_metric_lower in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
-        return "min"
-
+    if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
+        return 'min'
+
     # Default based on task type
-    if primary_task == "regression":
-        return "min"
-    return "max"
+    if primary_task == 'regression':
+        return 'min'
+    return 'max'


 def compute_single_metric(
@@ -450,111 +400,80 @@ def compute_single_metric(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     task_type: str,
-    user_ids: np.ndarray | None = None,
+    user_ids: np.ndarray | None = None
 ) -> float:
     """Compute a single metric given true and predicted values."""
     y_p_binary = (y_pred > 0.5).astype(int)
-
+
     try:
         metric_lower = metric.lower()
-
+
         # recall@K
-        if metric_lower.startswith("recall@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('recall@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_recall_at_k(y_true, y_pred, user_ids, k)

         # precision@K
-        if metric_lower.startswith("precision@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('precision@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_precision_at_k(y_true, y_pred, user_ids, k)

         # hitrate@K / hr@K
-        if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
-            k_str = metric_lower.split("@")[1]
+        if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
+            k_str = metric_lower.split('@')[1]
             k = int(k_str)
             return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)

         # mrr@K
-        if metric_lower.startswith("mrr@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('mrr@'):
+            k = int(metric_lower.split('@')[1])
             return _compute_mrr_at_k(y_true, y_pred, user_ids, k)

         # ndcg@K
-        if metric_lower.startswith("ndcg@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('ndcg@'):
+            k = int(metric_lower.split('@')[1])
            return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)

         # map@K
-        if metric_lower.startswith("map@"):
-            k = int(metric_lower.split("@")[1])
+        if metric_lower.startswith('map@'):
+            k = int(metric_lower.split('@')[1])
            return _compute_map_at_k(y_true, y_pred, user_ids, k)

         # cosine for matching task
-        if metric_lower == "cosine":
+        if metric_lower == 'cosine':
            return _compute_cosine_separation(y_true, y_pred)
-
-        if metric == "auc":
-            value = float(
-                roc_auc_score(
-                    y_true,
-                    y_pred,
-                    average="macro" if task_type == "multilabel" else None,
-                )
-            )
-        elif metric == "gauc":
+
+        if metric == 'auc':
+            value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
+        elif metric == 'gauc':
            value = float(compute_gauc(y_true, y_pred, user_ids))
-        elif metric == "ks":
+        elif metric == 'ks':
            value = float(compute_ks(y_true, y_pred))
-        elif metric == "logloss":
+        elif metric == 'logloss':
            value = float(log_loss(y_true, y_pred))
-        elif metric in ("accuracy", "acc"):
+        elif metric in ('accuracy', 'acc'):
            value = float(accuracy_score(y_true, y_p_binary))
-        elif metric == "precision":
-            value = float(
-                precision_score(
-                    y_true,
-                    y_p_binary,
-                    average="samples" if task_type == "multilabel" else "binary",
-                    zero_division=0,
-                )
-            )
-        elif metric == "recall":
-            value = float(
-                recall_score(
-                    y_true,
-                    y_p_binary,
-                    average="samples" if task_type == "multilabel" else "binary",
-                    zero_division=0,
-                )
-            )
-        elif metric == "f1":
-            value = float(
-                f1_score(
-                    y_true,
-                    y_p_binary,
-                    average="samples" if task_type == "multilabel" else "binary",
-                    zero_division=0,
-                )
-            )
-        elif metric == "micro_f1":
-            value = float(
-                f1_score(y_true, y_p_binary, average="micro", zero_division=0)
-            )
-        elif metric == "macro_f1":
-            value = float(
-                f1_score(y_true, y_p_binary, average="macro", zero_division=0)
-            )
-        elif metric == "mse":
+        elif metric == 'precision':
+            value = float(precision_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+        elif metric == 'recall':
+            value = float(recall_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+        elif metric == 'f1':
+            value = float(f1_score(y_true, y_p_binary, average='samples' if task_type == 'multilabel' else 'binary', zero_division=0))
+        elif metric == 'micro_f1':
+            value = float(f1_score(y_true, y_p_binary, average='micro', zero_division=0))
+        elif metric == 'macro_f1':
+            value = float(f1_score(y_true, y_p_binary, average='macro', zero_division=0))
+        elif metric == 'mse':
            value = float(mean_squared_error(y_true, y_pred))
-        elif metric == "mae":
+        elif metric == 'mae':
            value = float(mean_absolute_error(y_true, y_pred))
-        elif metric == "rmse":
+        elif metric == 'rmse':
            value = float(np.sqrt(mean_squared_error(y_true, y_pred)))
-        elif metric == "r2":
+        elif metric == 'r2':
            value = float(r2_score(y_true, y_pred))
-        elif metric == "mape":
+        elif metric == 'mape':
            value = float(compute_mape(y_true, y_pred))
-        elif metric == "msle":
+        elif metric == 'msle':
            value = float(compute_msle(y_true, y_pred))
        else:
            logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
@@ -562,40 +481,35 @@ def compute_single_metric(
     except Exception as exception:
         logging.warning(f"Failed to compute metric {metric}: {exception}")
         value = 0.0
-
+
     return value

-
 def evaluate_metrics(
     y_true: np.ndarray | None,
     y_pred: np.ndarray | None,
-    metrics: list[str],  # example: ['auc', 'logloss']
-    task: str | list[str],  # example: 'binary' or ['binary', 'regression']
-    target_names: list[str],  # example: ['target1', 'target2']
-    task_specific_metrics: (
-        dict[str, list[str]] | None
-    ) = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
-    user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
-) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
+    metrics: list[str],  # example: ['auc', 'logloss']
+    task: str | list[str],  # example: 'binary' or ['binary', 'regression']
+    target_names: list[str],  # example: ['target1', 'target2']
+    task_specific_metrics: dict[str, list[str]] | None = None,  # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+    user_ids: np.ndarray | None = None  # example: User IDs for GAUC computation
+) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
     result = {}
-
+
     if y_true is None or y_pred is None:
         return result
-
+
     # Main evaluation logic
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-
+
     # Single task evaluation
     if nums_task == 1:
         for metric in metrics:
             metric_lower = metric.lower()
-            value = compute_single_metric(
-                metric_lower, y_true, y_pred, primary_task, user_ids
-            )
+            value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
             result[metric_lower] = value
-
+
     # Multi-task evaluation
     else:
         for metric in metrics:
@@ -605,9 +519,7 @@ def evaluate_metrics(
                 should_compute = True
                 if task_specific_metrics is not None and task_idx < len(target_names):
                     task_name = target_names[task_idx]
-                    should_compute = metric_lower in task_specific_metrics.get(
-                        task_name, []
-                    )
+                    should_compute = metric_lower in task_specific_metrics.get(task_name, [])
                 else:
                     # Get task type for specific index
                     if isinstance(task, list) and task_idx < len(task):
@@ -615,51 +527,31 @@ def evaluate_metrics(
                     elif isinstance(task, str):
                         task_type = task
                     else:
-                        task_type = "binary"
-
-                    if task_type in ["binary", "multilabel"]:
-                        should_compute = metric_lower in {
-                            "auc",
-                            "ks",
-                            "logloss",
-                            "accuracy",
-                            "acc",
-                            "precision",
-                            "recall",
-                            "f1",
-                            "micro_f1",
-                            "macro_f1",
-                        }
-                    elif task_type == "regression":
-                        should_compute = metric_lower in {
-                            "mse",
-                            "mae",
-                            "rmse",
-                            "r2",
-                            "mape",
-                            "msle",
-                        }
-
+                        task_type = 'binary'
+
+                    if task_type in ['binary', 'multilabel']:
+                        should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
+                    elif task_type == 'regression':
+                        should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
+
                 if not should_compute:
                     continue
-
+
                 target_name = target_names[task_idx]
-
+
                 # Get task type for specific index
                 if isinstance(task, list) and task_idx < len(task):
                     task_type = task[task_idx]
                 elif isinstance(task, str):
                     task_type = task
                 else:
-                    task_type = "binary"
-
+                    task_type = 'binary'
+
                 y_true_task = y_true[:, task_idx]
                 y_pred_task = y_pred[:, task_idx]
-
+
                 # Compute metric
-                value = compute_single_metric(
-                    metric_lower, y_true_task, y_pred_task, task_type, user_ids
-                )
-                result[f"{metric_lower}_{target_name}"] = value
-
+                value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
+                result[f'{metric_lower}_{target_name}'] = value
+
     return result