nextrec-0.4.10-py3-none-any.whl → nextrec-0.4.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.10"
+__version__ = "0.4.12"
nextrec/basic/callback.py CHANGED
@@ -22,10 +22,10 @@ class Callback:
     """
     Base callback.
 
-    Notes (DDP):
-    - In distributed training, the training loop runs on every rank.
-    - For callbacks with side effects (saving, logging, etc.), set
-      ``run_on_main_process_only=True`` to avoid multi-rank duplication.
+    Notes for DDP training:
+    In distributed training, the training loop runs on every rank.
+    For callbacks with side effects (saving, logging, etc.), set
+    ``run_on_main_process_only=True`` to avoid multi-rank duplication.
     """
 
     run_on_main_process_only: bool = False
@@ -70,7 +70,7 @@ class Callback:
 
 
 class CallbackList:
-    """Container for managing multiple callbacks."""
+    """Generates a list of callbacks"""
 
     def __init__(self, callbacks: Optional[list[Callback]] = None):
         self.callbacks = callbacks or []
@@ -78,61 +78,41 @@ class CallbackList:
     def append(self, callback: Callback):
         self.callbacks.append(callback)
 
-    def set_model(self, model):
+    def call(self, fn_name: str, *args, **kwargs):
         for callback in self.callbacks:
-            callback.set_model(model)
+            if not callback.should_run():
+                continue
+            getattr(callback, fn_name)(*args, **kwargs)
+
+    def set_model(self, model):
+        self.call("set_model", model)
 
     def set_params(self, params: dict):
-        for callback in self.callbacks:
-            callback.set_params(params)
+        self.call("set_params", params)
 
     def on_train_begin(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_train_begin(logs)
+        self.call("on_train_begin", logs)
 
     def on_train_end(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_train_end(logs)
+        self.call("on_train_end", logs)
 
     def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_epoch_begin(epoch, logs)
+        self.call("on_epoch_begin", epoch, logs)
 
     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_epoch_end(epoch, logs)
+        self.call("on_epoch_end", epoch, logs)
 
     def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_batch_begin(batch, logs)
+        self.call("on_batch_begin", batch, logs)
 
     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_batch_end(batch, logs)
+        self.call("on_batch_end", batch, logs)
 
     def on_validation_begin(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_validation_begin(logs)
+        self.call("on_validation_begin", logs)
 
     def on_validation_end(self, logs: Optional[dict] = None):
-        for callback in self.callbacks:
-            if not callback.should_run():
-                continue
-            callback.on_validation_end(logs)
+        self.call("on_validation_end", logs)
 
 
 class EarlyStopper(Callback):
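The hunk above collapses the per-hook loops into a single `getattr`-based dispatcher that also centralizes the `should_run()` check. A minimal standalone sketch of that pattern (not the packaged classes; `LoggingCallback` and `TinyCallbackList` are illustrative):

```python
# Minimal sketch of the getattr-based dispatch introduced by CallbackList.call.
# LoggingCallback and TinyCallbackList are hypothetical, not part of nextrec.

class LoggingCallback:
    def should_run(self) -> bool:
        return True  # in DDP this would check the process rank

    def on_epoch_end(self, epoch, logs=None):
        print(f"epoch {epoch}: {logs or {}}")


class TinyCallbackList:
    def __init__(self, callbacks=None):
        self.callbacks = callbacks or []

    def call(self, fn_name, *args, **kwargs):
        # One dispatcher covers every hook: skip callbacks that should not run
        # on this process, then forward the call by attribute name.
        for cb in self.callbacks:
            if not cb.should_run():
                continue
            getattr(cb, fn_name)(*args, **kwargs)


callbacks = TinyCallbackList([LoggingCallback()])
callbacks.call("on_epoch_end", 0, {"loss": 0.42})  # prints: epoch 0: {'loss': 0.42}
```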
@@ -146,6 +126,20 @@ class EarlyStopper(Callback):
         restore_best_weights: bool = True,
         verbose: int = 1,
     ):
+        """
+        Callback to stop training early if no improvement.
+
+        Args:
+            monitor: Metric name to monitor.
+            patience: Number of epochs with no improvement after which training will be stopped.
+            mode: One of {'min', 'max'}. In 'min' mode, training will stop when the
+                monitored metric has stopped decreasing; in 'max' mode it will stop
+                when the monitored metric has stopped increasing.
+            min_delta: Minimum change in the monitored metric to qualify as an improvement.
+            restore_best_weights: Whether to restore model weights from the epoch with the best value
+                of the monitored metric.
+            verbose: Verbosity mode. 1: messages will be printed. 0: silent.
+        """
         super().__init__()
         self.monitor = monitor
         self.patience = patience
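A construction sketch based only on the arguments documented in the new docstring (the metric name and how the callback is wired into a trainer are illustrative, not taken from this diff):

```python
# Keyword names follow the docstring added above; everything else is assumed.
from nextrec.basic.callback import EarlyStopper

early_stopper = EarlyStopper(
    monitor="val_auc",          # metric name to watch
    patience=3,                 # stop after 3 epochs without improvement
    mode="max",                 # AUC should increase
    min_delta=1e-4,             # smaller changes do not count as improvement
    restore_best_weights=True,  # roll back to the best epoch's weights
    verbose=1,
)
```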
@@ -233,6 +227,7 @@ class CheckpointSaver(Callback):
         save_best_only: If True, only save when the model is considered the "best".
         save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
         verbose: Verbosity mode.
+        run_on_main_process_only: Whether to run this callback only on the main process in DDP.
     """
 
     def __init__(
@@ -274,7 +269,6 @@ class CheckpointSaver(Callback):
         self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
 
     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
-        logging.info("")
         logs = logs or {}
 
         should_save = False
@@ -283,9 +277,6 @@ class CheckpointSaver(Callback):
         elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
             should_save = True
 
-        if not should_save and self.save_best_only:
-            should_save = False
-
         # Check if this is the best model
         current = logs.get(self.monitor)
         is_best = False
@@ -297,11 +288,7 @@ class CheckpointSaver(Callback):
 
         if should_save:
             if not self.save_best_only or is_best:
-                checkpoint_path = (
-                    self.checkpoint_path.parent
-                    / f"{self.checkpoint_path.stem}{self.checkpoint_path.suffix}"
-                )
-                self.save_checkpoint(checkpoint_path, epoch, logs)
+                self.save_checkpoint(self.checkpoint_path, epoch, logs)
 
             if is_best:
                 # Use save_path directly without adding _best suffix since it may already contain it
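With the redundant `if not should_save and self.save_best_only` branch and the path re-derivation removed, the save decision in `CheckpointSaver.on_epoch_end` reduces to "save when the frequency is due, and only if best when `save_best_only` is set". A standalone sketch of that decision flow (`should_save_checkpoint` is a hypothetical helper, not a nextrec function):

```python
# Standalone sketch of the save decision implied by the diff above.
def should_save_checkpoint(
    epoch: int,
    save_freq,              # "epoch" or an int N (save every N epochs)
    save_best_only: bool,
    is_best: bool,
) -> bool:
    due = save_freq == "epoch" or (
        isinstance(save_freq, int) and (epoch + 1) % save_freq == 0
    )
    # When save_best_only is set, being "due" is not enough: the epoch must
    # also have produced the best monitored value so far.
    return due and (not save_best_only or is_best)


assert should_save_checkpoint(4, 5, save_best_only=False, is_best=False)
assert not should_save_checkpoint(3, 5, save_best_only=False, is_best=False)
assert not should_save_checkpoint(4, 5, save_best_only=True, is_best=False)
```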
@@ -371,7 +358,9 @@ class LearningRateScheduler(Callback):
         # Step the scheduler
         if hasattr(self.scheduler, "step"):
             # Some schedulers need metrics
-            if "val_loss" in (logs or {}) and hasattr(self.scheduler, "mode"):
+            if logs is None:
+                logs = {}
+            if "val_loss" in logs and hasattr(self.scheduler, "mode"):
                 self.scheduler.step(logs["val_loss"])
             else:
                 self.scheduler.step()
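The `hasattr(self.scheduler, "mode")` test distinguishes metric-driven schedulers such as `torch.optim.lr_scheduler.ReduceLROnPlateau` (which exposes a `mode` attribute and expects `step(metric)`) from step-based schedulers that take no argument. A small sketch of that rule using standard PyTorch schedulers; the toy model and optimizer are illustrative:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

logs = {"val_loss": 0.35}
for scheduler in (plateau, step_lr):
    if "val_loss" in logs and hasattr(scheduler, "mode"):
        scheduler.step(logs["val_loss"])  # ReduceLROnPlateau consumes the metric
    else:
        scheduler.step()                  # step-based schedulers take no argument
```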
@@ -399,7 +388,6 @@ class MetricsLogger(Callback):
         self.run_on_main_process_only = True
         self.log_freq = log_freq
         self.verbose = verbose
-        self.batch_count = 0
 
     def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
         if self.verbose > 0 and (
@@ -416,8 +404,10 @@ class MetricsLogger(Callback):
             logging.info(f"Epoch {epoch + 1}: {metrics_str}")
 
     def on_batch_end(self, batch: int, logs: Optional[dict] = None):
-        self.batch_count += 1
-        if self.verbose > 1 and self.log_freq == "batch":
+        if self.verbose > 1 and (
+            self.log_freq == "batch"
+            or (isinstance(self.log_freq, int) and (batch + 1) % self.log_freq == 0)
+        ):
             logs = logs or {}
             metrics_str = " - ".join(
                 [
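With `self.batch_count` removed, `MetricsLogger.on_batch_end` keys entirely off the `batch` index it receives: log every batch when `log_freq == "batch"`, or every N batches when `log_freq` is an integer. A tiny sketch of that cadence check:

```python
# Sketch of the batch-logging cadence implied by the new on_batch_end condition.
def should_log_batch(batch: int, log_freq, verbose: int = 2) -> bool:
    return verbose > 1 and (
        log_freq == "batch"
        or (isinstance(log_freq, int) and (batch + 1) % log_freq == 0)
    )

# With log_freq=50, 0-based batch indices 49 and 99 out of the first 100 would log.
assert [b for b in range(100) if should_log_batch(b, 50)] == [49, 99]
```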
nextrec/basic/features.py CHANGED
@@ -2,7 +2,7 @@
 Feature definitions
 
 Date: create on 27/10/2025
-Checkpoint: edit on 02/12/2025
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -12,22 +12,20 @@ from nextrec.utils.embedding import get_auto_embedding_dim
 from nextrec.utils.feature import normalize_to_list
 
 
-class BaseFeature(object):
+class BaseFeature:
     def __repr__(self):
         params = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
         param_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
         return f"{self.__class__.__name__}({param_str})"
 
 
-class SequenceFeature(BaseFeature):
+class EmbeddingFeature(BaseFeature):
     def __init__(
         self,
         name: str,
         vocab_size: int,
-        max_len: int = 20,
         embedding_name: str = "",
         embedding_dim: int | None = 4,
-        combiner: str = "mean",
         padding_idx: int | None = None,
         init_type: str = "normal",
         init_params: dict | None = None,
@@ -39,13 +37,15 @@ class SequenceFeature(BaseFeature):
     ):
         self.name = name
         self.vocab_size = vocab_size
-        self.max_len = max_len
         self.embedding_name = embedding_name or name
-        self.embedding_dim = embedding_dim or get_auto_embedding_dim(vocab_size)
+        self.embedding_dim = (
+            get_auto_embedding_dim(vocab_size)
+            if embedding_dim is None
+            else embedding_dim
+        )
 
         self.init_type = init_type
         self.init_params = init_params or {}
-        self.combiner = combiner
         self.padding_idx = padding_idx
         self.l1_reg = l1_reg
         self.l2_reg = l2_reg
@@ -54,13 +54,15 @@ class SequenceFeature(BaseFeature):
         self.freeze_pretrained = freeze_pretrained
 
 
-class SparseFeature(BaseFeature):
+class SequenceFeature(EmbeddingFeature):
     def __init__(
         self,
         name: str,
         vocab_size: int,
+        max_len: int = 20,
         embedding_name: str = "",
         embedding_dim: int | None = 4,
+        combiner: str = "mean",
         padding_idx: int | None = None,
         init_type: str = "normal",
         init_params: dict | None = None,
@@ -70,19 +72,26 @@ class SparseFeature(BaseFeature):
         pretrained_weight: torch.Tensor | None = None,
         freeze_pretrained: bool = False,
     ):
-        self.name = name
-        self.vocab_size = vocab_size
-        self.embedding_name = embedding_name or name
-        self.embedding_dim = embedding_dim or get_auto_embedding_dim(vocab_size)
+        super().__init__(
+            name=name,
+            vocab_size=vocab_size,
+            embedding_name=embedding_name,
+            embedding_dim=embedding_dim,
+            padding_idx=padding_idx,
+            init_type=init_type,
+            init_params=init_params,
+            l1_reg=l1_reg,
+            l2_reg=l2_reg,
+            trainable=trainable,
+            pretrained_weight=pretrained_weight,
+            freeze_pretrained=freeze_pretrained,
+        )
+        self.max_len = max_len
+        self.combiner = combiner
 
-        self.init_type = init_type
-        self.init_params = init_params or {}
-        self.padding_idx = padding_idx
-        self.l1_reg = l1_reg
-        self.l2_reg = l2_reg
-        self.trainable = trainable
-        self.pretrained_weight = pretrained_weight
-        self.freeze_pretrained = freeze_pretrained
+
+class SparseFeature(EmbeddingFeature):
+    pass
 
 
 class DenseFeature(BaseFeature):
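After this refactor, `SparseFeature` is simply `EmbeddingFeature`, and `SequenceFeature` layers `max_len` and `combiner` on top of it, so both share the same embedding-related keyword arguments. A hedged construction sketch using only parameters visible in the diff (vocabulary sizes and feature names are illustrative):

```python
# Constructor arguments below are the ones visible in the diff; how the
# features are consumed downstream is an assumption about the rest of nextrec.
from nextrec.basic.features import EmbeddingFeature, SparseFeature, SequenceFeature

item_id = SparseFeature(
    name="item_id",
    vocab_size=10_000,
    embedding_dim=None,        # None -> get_auto_embedding_dim(vocab_size)
)
click_history = SequenceFeature(
    name="click_history",
    vocab_size=10_000,
    max_len=50,
    combiner="mean",           # pool the sequence into a single embedding
    embedding_name="item_id",  # share the item embedding table
)

# Both feature types now derive from the shared EmbeddingFeature base class.
assert isinstance(item_id, EmbeddingFeature)
assert isinstance(click_history, EmbeddingFeature)
```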
@@ -95,7 +104,11 @@ class DenseFeature(BaseFeature):
     ):
         self.name = name
         self.input_dim = max(int(input_dim or 1), 1)
-        self.embedding_dim = embedding_dim or self.input_dim
+        self.embedding_dim = self.input_dim if embedding_dim is None else embedding_dim
+        if use_embedding and self.embedding_dim == 0:
+            raise ValueError(
+                "[Features Error] DenseFeature: use_embedding=True is incompatible with embedding_dim=0"
+            )
         if embedding_dim is not None and embedding_dim > 1:
             self.use_embedding = True
         else:
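Both here and in `EmbeddingFeature` above, the point of replacing `embedding_dim or fallback` with `fallback if embedding_dim is None else embedding_dim` is that an explicit `0` is falsy: the old form silently promoted it to the fallback, while the new form preserves it (and the new `ValueError` can then reject the invalid combination). A short illustration of the difference:

```python
# Why `x if embedding_dim is None else embedding_dim` differs from
# `embedding_dim or x` when 0 is passed explicitly.
input_dim = 8
for embedding_dim in (None, 0, 4):
    old = embedding_dim or input_dim                             # `or` treats 0 like None
    new = input_dim if embedding_dim is None else embedding_dim  # only None falls back
    print(embedding_dim, old, new)
# None 8 8
# 0    8 0   <- old style silently promoted 0 to input_dim
# 4    4 4
```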
nextrec/basic/layers.py CHANGED
@@ -2,7 +2,7 @@
 Layer implementations used across NextRec models.
 
 Date: create on 27/10/2025
-Checkpoint: edit on 19/12/2025
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -28,6 +28,16 @@ class PredictionLayer(nn.Module):
         use_bias: bool = True,
         return_logits: bool = False,
     ):
+        """
+        Prediction layer supporting binary and regression outputs.
+
+        Args:
+            task_type: A string or list of strings specifying the type of each task. supported types are "binary" and "regression".
+            task_dims: An integer or list of integers specifying the output dimension for each task.
+                If None, defaults to 1 for each task. If a single integer is provided, it is shared across all tasks.
+            use_bias: Whether to include a bias term in the prediction layer.
+            return_logits: If True, returns raw logits without applying activation functions.
+        """
         super().__init__()
         self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
         if len(self.task_types) == 0:
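A construction sketch based only on the arguments named in the new docstring (the shape of the input to `forward()` and the output format are not part of this diff and are left out):

```python
# Keyword names follow the docstring added above; anything else is assumed.
from nextrec.basic.layers import PredictionLayer

# One binary (CTR-style) task and one regression task, each with a 1-dim output.
head = PredictionLayer(
    task_type=["binary", "regression"],
    task_dims=[1, 1],
    use_bias=True,
    return_logits=False,  # apply task activations instead of returning raw logits
)
```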
@@ -253,8 +263,11 @@ class EmbeddingLayer(nn.Module):
         for feat in unique_feats.values():
             if isinstance(feat, DenseFeature):
                 in_dim = max(int(getattr(feat, "input_dim", 1)), 1)
-                emb_dim = getattr(feat, "embedding_dim", None)
-                out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+                if getattr(feat, "use_embedding", False):
+                    emb_dim = getattr(feat, "embedding_dim", None)
+                    out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
+                else:
+                    out_dim = in_dim
                 dim += out_dim
             elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
                 dim += feat.embedding_dim * feat.max_len
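The change means a `DenseFeature` that is not embedded now contributes its raw `input_dim` to the concatenated width rather than its `embedding_dim`. A standalone sketch of the accounting rule (the helper below is illustrative, not a nextrec function):

```python
# Width accounting per dense feature, mirroring the branch in the diff above.
def dense_output_dim(use_embedding: bool, input_dim: int, embedding_dim) -> int:
    in_dim = max(int(input_dim or 1), 1)
    if use_embedding:
        return max(int(embedding_dim), 1) if embedding_dim else in_dim
    return in_dim  # un-embedded dense features pass through at their raw width


assert dense_output_dim(False, 3, 16) == 3   # previously this would have counted 16
assert dense_output_dim(True, 3, 16) == 16
```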
@@ -518,13 +531,17 @@ class MultiHeadSelfAttention(nn.Module):
         self.use_residual = use_residual
         self.dropout_rate = dropout
 
-        self.W_Q = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.W_V = nn.Linear(embedding_dim, embedding_dim, bias=False)
-        self.W_O = nn.Linear(embedding_dim, embedding_dim, bias=False)
+        self.W_Q = nn.Linear(
+            embedding_dim, embedding_dim, bias=False
+        )  # Query projection
+        self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False)  # Key projection
+        self.W_V = nn.Linear(
+            embedding_dim, embedding_dim, bias=False
+        )  # Value projection
+        self.W_O = nn.Linear(
+            embedding_dim, embedding_dim, bias=False
+        )  # Output projection
 
-        if self.use_residual:
-            self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
         if use_layer_norm:
             self.layer_norm = nn.LayerNorm(embedding_dim)
         else:
@@ -537,81 +554,60 @@ class MultiHeadSelfAttention(nn.Module):
     def forward(
         self, x: torch.Tensor, attention_mask: torch.Tensor | None = None
     ) -> torch.Tensor:
-        """
-        Args:
-            x: [batch_size, seq_len, embedding_dim]
-            attention_mask: [batch_size, seq_len] or [batch_size, seq_len, seq_len], boolean mask where True indicates valid positions
-        Returns:
-            output: [batch_size, seq_len, embedding_dim]
-        """
-        batch_size, seq_len, _ = x.shape
-        Q = self.W_Q(x)  # [batch_size, seq_len, embedding_dim]
+        # x: [Batch, Length, Dim]
+        B, L, D = x.shape
+
+        Q = self.W_Q(x)
         K = self.W_K(x)
         V = self.W_V(x)
 
-        # Split into multiple heads: [batch_size, num_heads, seq_len, head_dim]
-        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(
+            1, 2
+        )  # [Batch, Heads, Length, head_dim]
+        K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
+
+        key_padding_mask = None
+        if attention_mask is not None:
+            if attention_mask.dim() == 2:  # [B,L], 1=valid, 0=pad
+                key_padding_mask = ~attention_mask.bool()
+                attn_mask = key_padding_mask[:, None, None, :]
+                attn_mask = attn_mask.expand(B, 1, L, L)
+            elif attention_mask.dim() == 3:  # [B,L,L], 1=allowed, 0=masked
+                attn_mask = (~attention_mask.bool()).view(B, 1, L, L)
+            else:
+                raise ValueError("attention_mask must be [B,L] or [B,L,L]")
+        else:
+            attn_mask = None
 
         if self.use_flash_attention:
-            # Use PyTorch 2.0+ Flash Attention
-            if attention_mask is not None:
-                # Convert mask to [batch_size, 1, seq_len, seq_len] format
-                if attention_mask.dim() == 2:
-                    # [B, L] -> [B, 1, 1, L]
-                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-                elif attention_mask.dim() == 3:
-                    # [B, L, L] -> [B, 1, L, L]
-                    attention_mask = attention_mask.unsqueeze(1)
-            attention_output = F.scaled_dot_product_attention(
+            attn = F.scaled_dot_product_attention(
                 Q,
                 K,
                 V,
-                attn_mask=attention_mask,
+                attn_mask=attn_mask,
                 dropout_p=self.dropout_rate if self.training else 0.0,
-            )
-            # Handle potential NaN values
-            attention_output = torch.nan_to_num(attention_output, nan=0.0)
+            )  # [B,H,L,dh]
         else:
-            # Fallback to standard attention
             scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
+            if attn_mask is not None:
+                scores = scores.masked_fill(attn_mask, float("-inf"))
+            attn_weights = torch.softmax(scores, dim=-1)
+            attn_weights = self.dropout(attn_weights)
+            attn = torch.matmul(attn_weights, V)  # [B,H,L,dh]
 
-            if attention_mask is not None:
-                # Process mask for standard attention
-                if attention_mask.dim() == 2:
-                    # [B, L] -> [B, 1, 1, L]
-                    attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-                elif attention_mask.dim() == 3:
-                    # [B, L, L] -> [B, 1, L, L]
-                    attention_mask = attention_mask.unsqueeze(1)
-                scores = scores.masked_fill(~attention_mask, float("-1e9"))
-
-            attention_weights = F.softmax(scores, dim=-1)
-            attention_weights = self.dropout(attention_weights)
-            attention_output = torch.matmul(
-                attention_weights, V
-            )  # [batch_size, num_heads, seq_len, head_dim]
-
-        # Concatenate heads
-        attention_output = attention_output.transpose(1, 2).contiguous()
-        attention_output = attention_output.view(
-            batch_size, seq_len, self.embedding_dim
-        )
+        attn = attn.transpose(1, 2).contiguous().view(B, L, D)
+        out = self.W_O(attn)
 
-        # Output projection
-        output = self.W_O(attention_output)
-
-        # Residual connection
         if self.use_residual:
-            output = output + self.W_Res(x)
-
-        # Layer normalization
+            out = out + x
         if self.layer_norm is not None:
-            output = self.layer_norm(output)
+            out = self.layer_norm(out)
 
-        output = F.relu(output)
-        return output
+        if key_padding_mask is not None:
+            out = out * (~key_padding_mask).unsqueeze(-1)
+
+        return out
 
 
 class AttentionPoolingLayer(nn.Module):
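The rewritten `forward` normalizes the mask once up front: a `[B, L]` mask with 1 = valid is turned into a key-padding mask, broadcast to `[B, 1, L, L]`, and the padded positions are zeroed in the final output. A hedged sketch of calling the layer with such a mask; the constructor arguments are assumptions, since the constructor signature is outside this hunk:

```python
# Only the forward() mask convention is taken from the diff above; the
# constructor arguments (embedding_dim, num_heads) are assumed.
import torch
from nextrec.basic.layers import MultiHeadSelfAttention

B, L, D = 2, 5, 16
attn = MultiHeadSelfAttention(embedding_dim=D, num_heads=4)

x = torch.randn(B, L, D)
# 1 = real token, 0 = padding; the second sequence has two padded positions.
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])

out = attn(x, attention_mask=mask)  # [B, L, D]; padded rows are zeroed at the end
```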
nextrec/basic/loggers.py CHANGED
@@ -2,7 +2,7 @@
 NextRec Basic Loggers
 
 Date: create on 27/10/2025
-Checkpoint: edit on 19/12/2025
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -185,7 +185,7 @@ class TrainingLogger:
     ) -> dict[str, float]:
         formatted: dict[str, float] = {}
         for key, value in metrics.items():
-            if isinstance(value, numbers.Number):
+            if isinstance(value, numbers.Real):
                 formatted[f"{split}/{key}"] = float(value)
             elif hasattr(value, "item"):
                 try:
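Switching the check from `numbers.Number` to `numbers.Real` narrows the accepted values to ones that `float()` can actually convert; complex values satisfy the old check but not the new one:

```python
import numbers

value = 3 + 4j
print(isinstance(value, numbers.Number))  # True  -> old check would accept it
print(isinstance(value, numbers.Real))    # False -> new check filters it out
# float(value) would raise TypeError: can't convert complex to float
```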
nextrec/basic/metrics.py CHANGED
@@ -2,7 +2,7 @@
 Metrics computation and configuration for model evaluation.
 
 Date: create on 27/10/2025
-Checkpoint: edit on 19/12/2025
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -49,8 +49,8 @@ TASK_DEFAULT_METRICS = {
 
 def check_user_id(*metric_sources: Any) -> bool:
     """Return True when GAUC or ranking@K metrics appear in the provided sources."""
-    metric_names: set[str] = set()
-    stack: list[Any] = list(metric_sources)
+    metric_names = set()
+    stack = list(metric_sources)
     while stack:
         item = stack.pop()
         if not item:
@@ -367,10 +367,12 @@ def configure_metrics(
     target_names: list[str],  # ['target1', 'target2']
 ) -> tuple[list[str], dict[str, list[str]] | None, str]:
     """Configure metrics based on task and user input."""
+
     primary_task = task[0] if isinstance(task, list) else task
     nums_task = len(task) if isinstance(task, list) else 1
-    metrics_list: list[str] = []
-    task_specific_metrics: dict[str, list[str]] | None = None
+    metrics_list = []
+    task_specific_metrics = None
+
     if isinstance(metrics, dict):
         metrics_list = []
         task_specific_metrics = {}
@@ -462,6 +464,7 @@ def compute_single_metric(
     user_ids: np.ndarray | None = None,
 ) -> float:
     """Compute a single metric given true and predicted values."""
+
     y_p_binary = (y_pred > 0.5).astype(int)
     metric_lower = metric.lower()
     try:
@@ -575,6 +578,7 @@ def evaluate_metrics(
     user_ids: np.ndarray | None = None,  # example: User IDs for GAUC computation
 ) -> dict:  # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
     """Evaluate specified metrics for given true and predicted values."""
+
     result = {}
     if y_true is None or y_pred is None:
         return result
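For context on the `user_ids` parameter above: `check_user_id` (shown earlier in this file's diff) decides whether user IDs are needed at all. A hedged usage sketch based only on its visible signature and docstring; which exact metric spellings count as "GAUC or ranking@K" is an assumption beyond the docstring:

```python
# Based on the signature check_user_id(*metric_sources) and its docstring:
# it scans any mix of metric sources and reports whether user-grouped
# metrics (GAUC, ranking@K) were requested.
from nextrec.basic.metrics import check_user_id

needs_user_ids = check_user_id(["auc", "gauc"])      # expected: True (GAUC requested)
no_user_ids = check_user_id(["auc", "logloss"])      # expected: False (plain metrics)
print(needs_user_ids, no_user_ids)
```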