nextrec 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +4 -8
  3. nextrec/basic/callback.py +1 -1
  4. nextrec/basic/features.py +33 -25
  5. nextrec/basic/layers.py +164 -601
  6. nextrec/basic/loggers.py +3 -4
  7. nextrec/basic/metrics.py +39 -115
  8. nextrec/basic/model.py +248 -174
  9. nextrec/basic/session.py +1 -5
  10. nextrec/data/__init__.py +12 -0
  11. nextrec/data/data_utils.py +3 -27
  12. nextrec/data/dataloader.py +26 -34
  13. nextrec/data/preprocessor.py +2 -1
  14. nextrec/loss/listwise.py +6 -4
  15. nextrec/loss/loss_utils.py +10 -6
  16. nextrec/loss/pairwise.py +5 -3
  17. nextrec/loss/pointwise.py +7 -13
  18. nextrec/models/match/mind.py +110 -1
  19. nextrec/models/multi_task/esmm.py +46 -27
  20. nextrec/models/multi_task/mmoe.py +48 -30
  21. nextrec/models/multi_task/ple.py +156 -141
  22. nextrec/models/multi_task/poso.py +413 -0
  23. nextrec/models/multi_task/share_bottom.py +43 -26
  24. nextrec/models/ranking/__init__.py +2 -0
  25. nextrec/models/ranking/autoint.py +1 -1
  26. nextrec/models/ranking/dcn.py +20 -1
  27. nextrec/models/ranking/dcn_v2.py +84 -0
  28. nextrec/models/ranking/deepfm.py +44 -18
  29. nextrec/models/ranking/dien.py +130 -27
  30. nextrec/models/ranking/masknet.py +13 -67
  31. nextrec/models/ranking/widedeep.py +39 -18
  32. nextrec/models/ranking/xdeepfm.py +34 -1
  33. nextrec/utils/common.py +26 -1
  34. nextrec-0.3.1.dist-info/METADATA +306 -0
  35. nextrec-0.3.1.dist-info/RECORD +56 -0
  36. {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/WHEEL +1 -1
  37. nextrec-0.2.6.dist-info/METADATA +0 -281
  38. nextrec-0.2.6.dist-info/RECORD +0 -54
  39. {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/loggers.py CHANGED
@@ -2,7 +2,8 @@
  NextRec Basic Loggers
 
  Date: create on 27/10/2025
- Author: Yang Zhou,zyaztec@gmail.com
+ Checkpoint: edit on 29/11/2025
+ Author: Yang Zhou, zyaztec@gmail.com
  """
 
 
@@ -10,10 +11,8 @@ import os
  import re
  import sys
  import copy
- import datetime
  import logging
- from pathlib import Path
- from nextrec.basic.session import resolve_save_path, create_session
+ from nextrec.basic.session import create_session
 
  ANSI_CODES = {
  'black': '\033[30m',
nextrec/basic/metrics.py CHANGED
@@ -2,6 +2,7 @@
  Metrics computation and configuration for model evaluation.
 
  Date: create on 27/10/2025
+ Checkpoint: edit on 29/11/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
  import logging
@@ -11,7 +12,6 @@ from sklearn.metrics import (
  accuracy_score, precision_score, recall_score, f1_score, r2_score,
  )
 
-
  CLASSIFICATION_METRICS = {'auc', 'gauc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
  REGRESSION_METRICS = {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
  TASK_DEFAULT_METRICS = {
@@ -21,8 +21,6 @@ TASK_DEFAULT_METRICS = {
  'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10','cosine']+ [f'recall@{k}' for k in (5,10,20)] + [f'ndcg@{k}' for k in (5,10,20)] + [f'mrr@{k}' for k in (5,10,20)]
  }
 
-
-
  def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Compute Kolmogorov-Smirnov statistic."""
  sorted_indices = np.argsort(y_pred)[::-1]
@@ -38,7 +36,6 @@ def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  return float(ks_value)
  return 0.0
 
-
  def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Compute Mean Absolute Percentage Error."""
  mask = y_true != 0
@@ -46,83 +43,62 @@ def compute_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  return float(np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100)
  return 0.0
 
-
  def compute_msle(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Compute Mean Squared Log Error."""
  y_pred_pos = np.maximum(y_pred, 0)
  return float(mean_squared_error(np.log1p(y_true), np.log1p(y_pred_pos)))
 
-
- def compute_gauc(
- y_true: np.ndarray,
- y_pred: np.ndarray,
- user_ids: np.ndarray | None = None
- ) -> float:
+ def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None = None) -> float:
  if user_ids is None:
  # If no user_ids provided, fall back to regular AUC
  try:
  return float(roc_auc_score(y_true, y_pred))
  except:
  return 0.0
-
  # Group by user_id and calculate AUC for each user
  user_aucs = []
  user_weights = []
-
  unique_users = np.unique(user_ids)
-
  for user_id in unique_users:
  mask = user_ids == user_id
  user_y_true = y_true[mask]
  user_y_pred = y_pred[mask]
-
  # Skip users with only one class (cannot compute AUC)
  if len(np.unique(user_y_true)) < 2:
  continue
-
  try:
  user_auc = roc_auc_score(user_y_true, user_y_pred)
  user_aucs.append(user_auc)
  user_weights.append(len(user_y_true))
  except:
  continue
-
  if len(user_aucs) == 0:
  return 0.0
-
  # Weighted average
  user_aucs = np.array(user_aucs)
  user_weights = np.array(user_weights)
  gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
-
  return gauc
 
-
  def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
  """Group sample indices by user_id. If user_ids is None, treat all as one group."""
  if user_ids is None:
  return [np.arange(n_samples)]
-
  user_ids = np.asarray(user_ids)
  if user_ids.shape[0] != n_samples:
- logging.warning(
- "user_ids length (%d) != number of samples (%d), "
- "treating all samples as a single group for ranking metrics.",
- user_ids.shape[0],
- n_samples,
- )
+ logging.warning(f"[Metrics Warning: GAUC] user_ids length {user_ids.shape[0]} != number of samples {n_samples}, treating all samples as a single group for ranking metrics.")
  return [np.arange(n_samples)]
-
  unique_users = np.unique(user_ids)
  groups = [np.where(user_ids == u)[0] for u in unique_users]
  return groups
 
-
- def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
+ """Compute Precision@K."""
+ if user_ids is None:
+ raise ValueError("[Metrics Error: Precision@K] user_ids must be provided for Precision@K computation.")
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
  precisions = []
  for idx in groups:
  if idx.size == 0:
@@ -134,16 +110,15 @@ def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np
  topk = order[:k_user]
  hits = labels[topk].sum()
  precisions.append(hits / float(k_user))
-
  return float(np.mean(precisions)) if precisions else 0.0
 
-
- def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
  """Compute Recall@K."""
+ if user_ids is None:
+ raise ValueError("[Metrics Error: Recall@K] user_ids must be provided for Recall@K computation.")
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
  recalls = []
  for idx in groups:
  if idx.size == 0:
@@ -151,46 +126,44 @@ def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.nd
  labels = y_true[idx]
  num_pos = labels.sum()
  if num_pos == 0:
- continue # 跳过没有正样本的用户
+ continue # dont count users with no positive labels
  scores = y_pred[idx]
  order = np.argsort(scores)[::-1]
  k_user = min(k, idx.size)
  topk = order[:k_user]
  hits = labels[topk].sum()
  recalls.append(hits / float(num_pos))
-
  return float(np.mean(recalls)) if recalls else 0.0
 
-
- def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
  """Compute HitRate@K."""
+ if user_ids is None:
+ raise ValueError("[Metrics Error: HitRate@K] user_ids must be provided for HitRate@K computation.")
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
  hits_per_user = []
  for idx in groups:
  if idx.size == 0:
  continue
  labels = y_true[idx]
  if labels.sum() == 0:
- continue # 无正样本用户不计入
+ continue # dont count users with no positive labels
  scores = y_pred[idx]
  order = np.argsort(scores)[::-1]
  k_user = min(k, idx.size)
  topk = order[:k_user]
  hits = labels[topk].sum()
  hits_per_user.append(1.0 if hits > 0 else 0.0)
-
  return float(np.mean(hits_per_user)) if hits_per_user else 0.0
 
-
- def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
  """Compute MRR@K."""
+ if user_ids is None:
+ raise ValueError("[Metrics Error: MRR@K] user_ids must be provided for MRR@K computation.")
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
  mrrs = []
  for idx in groups:
  if idx.size == 0:
@@ -203,17 +176,14 @@ def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
  k_user = min(k, idx.size)
  topk = order[:k_user]
  ranked_labels = labels[order]
-
  rr = 0.0
  for rank, lab in enumerate(ranked_labels[:k_user], start=1):
  if lab > 0:
  rr = 1.0 / rank
  break
  mrrs.append(rr)
-
  return float(np.mean(mrrs)) if mrrs else 0.0
 
-
  def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
  k_user = min(k, labels.size)
  if k_user == 0:
@@ -222,13 +192,13 @@ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
  discounts = np.log2(np.arange(2, k_user + 2))
  return float(np.sum(gains / discounts))
 
-
- def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
  """Compute NDCG@K."""
+ if user_ids is None:
+ raise ValueError("[Metrics Error: NDCG@K] user_ids must be provided for NDCG@K computation.")
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
  ndcgs = []
  for idx in groups:
  if idx.size == 0:
@@ -237,27 +207,25 @@ def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndar
  if labels.sum() == 0:
  continue
  scores = y_pred[idx]
-
  order = np.argsort(scores)[::-1]
  ranked_labels = labels[order]
  dcg = _compute_dcg_at_k(ranked_labels, k)
-
  # ideal DCG
  ideal_labels = np.sort(labels)[::-1]
  idcg = _compute_dcg_at_k(ideal_labels, k)
  if idcg == 0.0:
  continue
  ndcgs.append(dcg / idcg)
-
  return float(np.mean(ndcgs)) if ndcgs else 0.0
 
 
- def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray | None, k: int) -> float:
+ def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
  """Mean Average Precision@K."""
+ if user_ids is None:
+ raise ValueError("[Metrics Error: MAP@K] user_ids must be provided for MAP@K computation.")
  y_true = (y_true > 0).astype(int)
  n = len(y_true)
  groups = _group_indices_by_user(user_ids, n)
-
  aps = []
  for idx in groups:
  if idx.size == 0:
@@ -266,23 +234,19 @@ def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
  num_pos = labels.sum()
  if num_pos == 0:
  continue
-
  scores = y_pred[idx]
  order = np.argsort(scores)[::-1]
  k_user = min(k, idx.size)
-
  hits = 0
  sum_precisions = 0.0
  for rank, i in enumerate(order[:k_user], start=1):
  if labels[i] > 0:
  hits += 1
  sum_precisions += hits / float(rank)
-
  if hits == 0:
  aps.append(0.0)
  else:
  aps.append(sum_precisions / float(num_pos))
-
  return float(np.mean(aps)) if aps else 0.0
 
 
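Note on the GAUC and per-user ranking helpers above: compute_gauc reports a weighted average of per-user AUCs, where each user's weight is its sample count, and users whose labels contain only one class are skipped. A self-contained sketch of that weighting with plain numpy and scikit-learn on made-up values (not a call into nextrec):

import numpy as np
from sklearn.metrics import roc_auc_score

# Toy batch: user "a" has both classes; user "b" has only positives and is
# skipped, mirroring the single-class check in compute_gauc above.
user_ids = np.array(["a", "a", "a", "b", "b"])
y_true = np.array([1, 0, 1, 1, 1])
y_pred = np.array([0.9, 0.3, 0.6, 0.8, 0.7])

aucs, weights = [], []
for u in np.unique(user_ids):
    mask = user_ids == u
    if len(np.unique(y_true[mask])) < 2:
        continue  # user "b" contributes nothing
    aucs.append(roc_auc_score(y_true[mask], y_pred[mask]))
    weights.append(mask.sum())

gauc = float(np.sum(np.array(aucs) * np.array(weights)) / np.sum(weights))
print(gauc)  # 1.0 here: user "a" ranks both positives above its negative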
@@ -308,31 +272,22 @@ def configure_metrics(
  """Configure metrics based on task and user input."""
  primary_task = task[0] if isinstance(task, list) else task
  nums_task = len(task) if isinstance(task, list) else 1
-
  metrics_list: list[str] = []
  task_specific_metrics: dict[str, list[str]] | None = None
-
  if isinstance(metrics, dict):
  metrics_list = []
  task_specific_metrics = {}
  for task_name, task_metrics in metrics.items():
  if task_name not in target_names:
- logging.warning(
- "Task '%s' not found in targets %s, skipping its metrics",
- task_name,
- target_names,
- )
+ logging.warning(f"[Metrics Warning] Task {task_name} not found in targets {target_names}, skipping its metrics")
  continue
-
  lowered = [m.lower() for m in task_metrics]
  task_specific_metrics[task_name] = lowered
  for metric in lowered:
  if metric not in metrics_list:
  metrics_list.append(metric)
-
  elif metrics:
  metrics_list = [m.lower() for m in metrics]
-
  else:
  # No user provided metrics, derive per task type
  if nums_task > 1 and isinstance(task, list):
@@ -350,26 +305,20 @@ def configure_metrics(
  if primary_task not in TASK_DEFAULT_METRICS:
  raise ValueError(f"Unsupported task type: {primary_task}")
  metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
  if not metrics_list:
  # Inline get_default_metrics_for_task logic
  if primary_task not in TASK_DEFAULT_METRICS:
  raise ValueError(f"Unsupported task type: {primary_task}")
  metrics_list = TASK_DEFAULT_METRICS[primary_task]
-
  best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
-
  return metrics_list, task_specific_metrics, best_metrics_mode
 
-
  def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
  """Determine if metric should be maximized or minimized."""
  first_metric_lower = first_metric.lower()
-
  # Metrics that should be maximized
  if first_metric_lower in {'auc', 'gauc', 'ks', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'r2', 'micro_f1', 'macro_f1'}:
  return 'max'
-
  # Ranking metrics that should be maximized (with @K suffix)
  if (first_metric_lower.startswith('recall@') or
  first_metric_lower.startswith('precision@') or
@@ -379,21 +328,17 @@ def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
  first_metric_lower.startswith('ndcg@') or
  first_metric_lower.startswith('map@')):
  return 'max'
-
  # Cosine separation should be maximized
  if first_metric_lower == 'cosine':
  return 'max'
-
  # Metrics that should be minimized
  if first_metric_lower in {'logloss', 'mse', 'mae', 'rmse', 'mape', 'msle'}:
  return 'min'
-
  # Default based on task type
  if primary_task == 'regression':
  return 'min'
  return 'max'
 
-
  def compute_single_metric(
  metric: str,
  y_true: np.ndarray,
@@ -403,45 +348,36 @@ def compute_single_metric(
  ) -> float:
  """Compute a single metric given true and predicted values."""
  y_p_binary = (y_pred > 0.5).astype(int)
-
  try:
  metric_lower = metric.lower()
-
  # recall@K
  if metric_lower.startswith('recall@'):
  k = int(metric_lower.split('@')[1])
- return _compute_recall_at_k(y_true, y_pred, user_ids, k)
-
+ return _compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
  # precision@K
  if metric_lower.startswith('precision@'):
  k = int(metric_lower.split('@')[1])
- return _compute_precision_at_k(y_true, y_pred, user_ids, k)
-
+ return _compute_precision_at_k(y_true, y_pred, user_ids, k) # type: ignore
  # hitrate@K / hr@K
  if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
  k_str = metric_lower.split('@')[1]
  k = int(k_str)
- return _compute_hitrate_at_k(y_true, y_pred, user_ids, k)
-
+ return _compute_hitrate_at_k(y_true, y_pred, user_ids, k) # type: ignore
  # mrr@K
  if metric_lower.startswith('mrr@'):
  k = int(metric_lower.split('@')[1])
- return _compute_mrr_at_k(y_true, y_pred, user_ids, k)
-
+ return _compute_mrr_at_k(y_true, y_pred, user_ids, k) # type: ignore
  # ndcg@K
  if metric_lower.startswith('ndcg@'):
  k = int(metric_lower.split('@')[1])
- return _compute_ndcg_at_k(y_true, y_pred, user_ids, k)
-
+ return _compute_ndcg_at_k(y_true, y_pred, user_ids, k) # type: ignore
  # map@K
  if metric_lower.startswith('map@'):
  k = int(metric_lower.split('@')[1])
- return _compute_map_at_k(y_true, y_pred, user_ids, k)
-
+ return _compute_map_at_k(y_true, y_pred, user_ids, k) # type: ignore
  # cosine for matching task
  if metric_lower == 'cosine':
  return _compute_cosine_separation(y_true, y_pred)
-
  if metric == 'auc':
  value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
  elif metric == 'gauc':
@@ -475,12 +411,11 @@ def compute_single_metric(
  elif metric == 'msle':
  value = float(compute_msle(y_true, y_pred))
  else:
- logging.warning(f"Metric '{metric}' is not supported, returning 0.0")
+ logging.warning(f"[Metric Warning] Metric '{metric}' is not supported, returning 0.0")
  value = 0.0
  except Exception as exception:
- logging.warning(f"Failed to compute metric {metric}: {exception}")
+ logging.warning(f"[Metric Warning] Failed to compute metric {metric}: {exception}")
  value = 0.0
-
  return value
 
  def evaluate_metrics(
@@ -494,21 +429,17 @@ def evaluate_metrics(
  ) -> dict: # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
  """Evaluate specified metrics for given true and predicted values."""
  result = {}
-
  if y_true is None or y_pred is None:
  return result
-
  # Main evaluation logic
  primary_task = task[0] if isinstance(task, list) else task
  nums_task = len(task) if isinstance(task, list) else 1
-
  # Single task evaluation
  if nums_task == 1:
  for metric in metrics:
  metric_lower = metric.lower()
  value = compute_single_metric(metric_lower, y_true, y_pred, primary_task, user_ids)
  result[metric_lower] = value
-
  # Multi-task evaluation
  else:
  for metric in metrics:
@@ -526,31 +457,24 @@ def evaluate_metrics(
  elif isinstance(task, str):
  task_type = task
  else:
- task_type = 'binary'
-
+ task_type = 'binary'
  if task_type in ['binary', 'multilabel']:
  should_compute = metric_lower in {'auc', 'ks', 'logloss', 'accuracy', 'acc', 'precision', 'recall', 'f1', 'micro_f1', 'macro_f1'}
  elif task_type == 'regression':
- should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
-
+ should_compute = metric_lower in {'mse', 'mae', 'rmse', 'r2', 'mape', 'msle'}
  if not should_compute:
- continue
-
- target_name = target_names[task_idx]
-
+ continue
+ target_name = target_names[task_idx]
  # Get task type for specific index
  if isinstance(task, list) and task_idx < len(task):
  task_type = task[task_idx]
  elif isinstance(task, str):
  task_type = task
  else:
- task_type = 'binary'
-
+ task_type = 'binary'
  y_true_task = y_true[:, task_idx]
- y_pred_task = y_pred[:, task_idx]
-
+ y_pred_task = y_pred[:, task_idx]
  # Compute metric
  value = compute_single_metric(metric_lower, y_true_task, y_pred_task, task_type, user_ids)
  result[f'{metric_lower}_{target_name}'] = value
-
  return result
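For context, the 0.3.1 revision of nextrec/basic/metrics.py tightens the @K helpers (user_ids is now required and a ValueError is raised when it is missing) and prefixes warnings with [Metric Warning] / [Metrics Warning] tags. A minimal usage sketch of the two helpers whose full signatures appear in this diff, assuming the module imports as nextrec.basic.metrics; the input arrays and the call pattern below are illustrative only, not taken from the package documentation:

import numpy as np
from nextrec.basic.metrics import compute_gauc, get_best_metric_mode  # module path as shown in this diff

y_true = np.array([1, 0, 1, 0])
y_pred = np.array([0.8, 0.4, 0.7, 0.9])
user_ids = np.array([1, 1, 2, 2])

print(compute_gauc(y_true, y_pred, user_ids))  # per-user AUCs, weighted by each user's sample count
print(compute_gauc(y_true, y_pred))            # without user_ids it falls back to plain AUC
print(get_best_metric_mode("recall@10", "matching"))  # 'max': @K ranking metrics are maximized
print(get_best_metric_mode("logloss", "binary"))      # 'min': loss-style metrics are minimized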