nextrec-0.4.23-py3-none-any.whl → nextrec-0.4.24-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.23"
+ __version__ = "0.4.24"
nextrec/basic/layers.py CHANGED
@@ -20,6 +20,7 @@ import torch.nn.functional as F
  from nextrec.basic.activation import activation_layer
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.utils.torch_utils import get_initializer
+ from nextrec.utils.types import ActivationName


  class PredictionLayer(nn.Module):
@@ -590,71 +591,48 @@ class MLP(nn.Module):
  def __init__(
  self,
  input_dim: int,
- output_layer: bool = True,
- dims: list[int] | None = None,
+ hidden_dims: list[int] | None = None,
+ output_dim: int | None = 1,
  dropout: float = 0.0,
- activation: Literal[
- "dice",
- "relu",
- "relu6",
- "elu",
- "selu",
- "leaky_relu",
- "prelu",
- "gelu",
- "sigmoid",
- "tanh",
- "softplus",
- "softsign",
- "hardswish",
- "mish",
- "silu",
- "swish",
- "hardsigmoid",
- "tanhshrink",
- "softshrink",
- "none",
- "linear",
- "identity",
- ] = "relu",
- use_norm: bool = True,
- norm_type: Literal["batch_norm", "layer_norm"] = "layer_norm",
+ activation: ActivationName = "relu",
+ norm_type: Literal["batch_norm", "layer_norm", "none"] = "none",
+ output_activation: ActivationName = "none",
  ):
  """
  Multi-Layer Perceptron (MLP) module.

  Args:
  input_dim: Dimension of the input features.
- output_layer: Whether to include the final output layer. If False, the MLP will output the last hidden layer, else it will output a single value.
- dims: List of hidden layer dimensions. If None, no hidden layers are added.
+ output_dim: Output dimension of the final layer. If None, no output layer is added.
+ hidden_dims: List of hidden layer dimensions. If None, no hidden layers are added.
  dropout: Dropout rate between layers.
  activation: Activation function to use between layers.
- use_norm: Whether to use normalization layers.
- norm_type: Type of normalization to use ("batch_norm" or "layer_norm").
+ norm_type: Type of normalization to use ("batch_norm", "layer_norm", or "none").
+ output_activation: Activation function applied after the output layer.
  """
  super().__init__()
- if dims is None:
- dims = []
+ hidden_dims = hidden_dims or []
  layers = []
  current_dim = input_dim
- for i_dim in dims:
+ for i_dim in hidden_dims:
  layers.append(nn.Linear(current_dim, i_dim))
- if use_norm:
- if norm_type == "batch_norm":
- # **IMPORTANT** be careful when using BatchNorm1d in distributed training, nextrec does not support sync batch norm now
- layers.append(nn.BatchNorm1d(i_dim))
- elif norm_type == "layer_norm":
- layers.append(nn.LayerNorm(i_dim))
- else:
- raise ValueError(f"Unsupported norm_type: {norm_type}")
+ if norm_type == "batch_norm":
+ # **IMPORTANT** be careful when using BatchNorm1d in distributed training, nextrec does not support sync batch norm now
+ layers.append(nn.BatchNorm1d(i_dim))
+ elif norm_type == "layer_norm":
+ layers.append(nn.LayerNorm(i_dim))
+ elif norm_type != "none":
+ raise ValueError(f"Unsupported norm_type: {norm_type}")

  layers.append(activation_layer(activation))
  layers.append(nn.Dropout(p=dropout))
  current_dim = i_dim
  # output layer
- if output_layer:
- layers.append(nn.Linear(current_dim, 1))
- self.output_dim = 1
+ if output_dim is not None:
+ layers.append(nn.Linear(current_dim, output_dim))
+ if output_activation != "none":
+ layers.append(activation_layer(output_activation))
+ self.output_dim = output_dim
  else:
  self.output_dim = current_dim
  self.mlp = nn.Sequential(*layers)
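
A minimal usage sketch of the refactored constructor (import path taken from the file being diffed; everything else here is illustrative):

import torch
from nextrec.basic.layers import MLP

mlp = MLP(
    input_dim=64,
    hidden_dims=[128, 32],     # replaces `dims`
    output_dim=1,              # replaces `output_layer=True`; pass None to expose the last hidden layer
    activation="relu",
    norm_type="none",          # "batch_norm" | "layer_norm" | "none" (new default)
    output_activation="none",  # new: e.g. "sigmoid" for a gate-style head
)
out = mlp(torch.randn(8, 64))  # -> shape (8, 1); mlp.output_dim == 1

Migration implied by the diff: output_layer=False becomes output_dim=None, and use_norm=False becomes norm_type="none".
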
@@ -663,6 +641,47 @@ class MLP(nn.Module):
  return self.mlp(x)


+ class GateMLP(nn.Module):
+ """
+ Lightweight gate network: sigmoid MLP scaled by a constant factor.
+
+ Args:
+ input_dim: Dimension of the input features.
+ hidden_dim: Dimension of the hidden layer. If None, defaults to output_dim.
+ output_dim: Output dimension of the gate.
+ activation: Activation function to use in the hidden layer.
+ dropout: Dropout rate between layers.
+ use_bn: Whether to use batch normalization.
+ scale_factor: Scaling factor applied to the sigmoid output.
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dim: int | None,
+ output_dim: int,
+ activation: ActivationName = "relu",
+ dropout: float = 0.0,
+ use_bn: bool = False,
+ scale_factor: float = 2.0,
+ ) -> None:
+ super().__init__()
+ hidden_dim = output_dim if hidden_dim is None else hidden_dim
+ self.gate = MLP(
+ input_dim=input_dim,
+ hidden_dims=[hidden_dim],
+ output_dim=output_dim,
+ activation=activation,
+ dropout=dropout,
+ norm_type="batch_norm" if use_bn else "none",
+ output_activation="sigmoid",
+ )
+ self.scale_factor = scale_factor
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ return self.gate(inputs) * self.scale_factor
+
+
  class FM(nn.Module):
  def __init__(self, reduce_sum: bool = True):
  super().__init__()
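
GateMLP composes the refactored MLP above: a one-hidden-layer sigmoid gate whose output is multiplied by scale_factor, so the default 2.0 yields gate values in (0, 2) that can damp or amplify each dimension. A usage sketch (import path assumed from this file):

import torch
from nextrec.basic.layers import GateMLP

gate = GateMLP(input_dim=32, hidden_dim=None, output_dim=32)  # hidden_dim=None falls back to output_dim
emb = torch.randn(4, 32)
gated = emb * gate(emb)  # elementwise reweighting with gate values in (0, 2.0)
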
@@ -1007,3 +1026,34 @@ class RMSNorm(torch.nn.Module):
  variance = torch.mean(x**2, dim=-1, keepdim=True)
  x_normalized = x * torch.rsqrt(variance + self.eps)
  return self.weight * x_normalized
+
+
+ class DomainBatchNorm(nn.Module):
+ """Domain-specific BatchNorm (applied per-domain with a shared interface)."""
+
+ def __init__(self, num_features: int, num_domains: int):
+ super().__init__()
+ if num_domains < 1:
+ raise ValueError("num_domains must be >= 1")
+ self.bns = nn.ModuleList(
+ [nn.BatchNorm1d(num_features) for _ in range(num_domains)]
+ )
+
+ def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
+ if x.dim() != 2:
+ raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
+ output = x.clone()
+ if domain_mask.dim() == 1:
+ domain_ids = domain_mask.long()
+ for idx, bn in enumerate(self.bns):
+ mask = domain_ids == idx
+ if mask.any():
+ output[mask] = bn(x[mask])
+ return output
+ if domain_mask.dim() != 2:
+ raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
+ for idx, bn in enumerate(self.bns):
+ mask = domain_mask[:, idx] > 0
+ if mask.any():
+ output[mask] = bn(x[mask])
+ return output
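
DomainBatchNorm routes each row through the BatchNorm1d of its domain and accepts either 1D integer domain ids or a 2D one-hot mask. A sketch of both call forms (import path assumed; eval() keeps the toy batch from updating running statistics):

import torch
import torch.nn.functional as F
from nextrec.basic.layers import DomainBatchNorm

dbn = DomainBatchNorm(num_features=16, num_domains=3).eval()
x = torch.randn(6, 16)
domain_ids = torch.tensor([0, 1, 2, 0, 1, 2])
out_ids = dbn(x, domain_ids)                            # 1D index form
out_hot = dbn(x, F.one_hot(domain_ids, num_classes=3))  # 2D one-hot form
assert torch.allclose(out_ids, out_hot)
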
nextrec/basic/metrics.py CHANGED
@@ -2,7 +2,7 @@
  Metrics computation and configuration for model evaluation.

  Date: create on 27/10/2025
- Checkpoint: edit on 29/12/2025
+ Checkpoint: edit on 30/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -21,21 +21,9 @@ from sklearn.metrics import (
  recall_score,
  roc_auc_score,
  )
+ from nextrec.utils.types import TaskTypeName, MetricsName
+

- CLASSIFICATION_METRICS = {
- "auc",
- "gauc",
- "ks",
- "logloss",
- "accuracy",
- "acc",
- "precision",
- "recall",
- "f1",
- "micro_f1",
- "macro_f1",
- }
- REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
  TASK_DEFAULT_METRICS = {
  "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
  "regression": ["mse", "mae", "rmse", "r2", "mape"],
@@ -58,7 +46,7 @@ def check_user_id(*metric_sources: Any) -> bool:
  stack.extend(item.values())
  continue
  if isinstance(item, str):
- metric_names.add(item.lower())
+ metric_names.add(item)
  continue
  try:
  stack.extend(item)
@@ -361,9 +349,9 @@ def compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:


  def configure_metrics(
- task: str | list[str], # 'binary' or ['binary', 'regression']
+ task: TaskTypeName | list[TaskTypeName], # 'binary' or ['binary', 'regression']
  metrics: (
- list[str] | dict[str, list[str]] | None
+ list[MetricsName] | dict[str, list[MetricsName]] | None
  ), # ['auc', 'logloss'] or {'task1': ['auc'], 'task2': ['mse']}
  target_names: list[str], # ['target1', 'target2']
  ) -> tuple[list[str], dict[str, list[str]] | None, str]:
@@ -383,13 +371,12 @@ def configure_metrics(
  f"[Metrics Warning] Task {task_name} not found in targets {target_names}, skipping its metrics"
  )
  continue
- lowered = [m.lower() for m in task_metrics]
- task_specific_metrics[task_name] = lowered
- for metric in lowered:
+ task_specific_metrics[task_name] = task_metrics
+ for metric in task_metrics:
  if metric not in metrics_list:
  metrics_list.append(metric)
  elif metrics:
- metrics_list = [m.lower() for m in metrics]
+ metrics_list = [m for m in metrics]
  else:
  # No user provided metrics, derive per task type
  if nums_task > 1 and isinstance(task, list):
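
Behavioral change in this hunk: metric names are no longer lowercased, so callers must pass canonical lowercase names ("AUC" no longer normalizes to "auc"). A sketch of the dict form (import path assumed):

from nextrec.basic.metrics import configure_metrics

metrics_list, per_task, mode = configure_metrics(
    task=["binary", "regression"],
    metrics={"click": ["auc", "logloss"], "price": ["mse"]},
    target_names=["click", "price"],
)
# expected: metrics_list == ["auc", "logloss", "mse"], with per_task mapping each target to its own metric list
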
@@ -416,11 +403,10 @@ def configure_metrics(
  return metrics_list, task_specific_metrics, best_metrics_mode


- def getbest_metric_mode(first_metric: str, primary_task: str) -> str:
+ def getbest_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
  """Determine if metric should be maximized or minimized."""
- first_metric_lower = first_metric.lower()
  # Metrics that should be maximized
- if first_metric_lower in {
+ if first_metric in {
  "auc",
  "gauc",
  "ks",
@@ -436,20 +422,20 @@ def getbest_metric_mode(first_metric: str, primary_task: str) -> str:
  return "max"
  # Ranking metrics that should be maximized (with @K suffix)
  if (
- first_metric_lower.startswith("recall@")
- or first_metric_lower.startswith("precision@")
- or first_metric_lower.startswith("hitrate@")
- or first_metric_lower.startswith("hr@")
- or first_metric_lower.startswith("mrr@")
- or first_metric_lower.startswith("ndcg@")
- or first_metric_lower.startswith("map@")
+ first_metric.startswith("recall@")
+ or first_metric.startswith("precision@")
+ or first_metric.startswith("hitrate@")
+ or first_metric.startswith("hr@")
+ or first_metric.startswith("mrr@")
+ or first_metric.startswith("ndcg@")
+ or first_metric.startswith("map@")
  ):
  return "max"
  # Cosine separation should be maximized
- if first_metric_lower == "cosine":
+ if first_metric == "cosine":
  return "max"
  # Metrics that should be minimized
- if first_metric_lower in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
+ if first_metric in {"logloss", "mse", "mae", "rmse", "mape", "msle"}:
  return "min"
  # Default based on task type
  if primary_task == "regression":
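
With the .lower() call removed, the mode lookup is case-sensitive as well. Quick checks of the logic above (import path assumed):

from nextrec.basic.metrics import getbest_metric_mode

assert getbest_metric_mode("ndcg@10", "binary") == "max"   # @K ranking metrics maximize
assert getbest_metric_mode("rmse", "regression") == "min"  # error metrics minimize
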
@@ -458,7 +444,7 @@ def getbest_metric_mode(first_metric: str, primary_task: str) -> str:


  def compute_single_metric(
- metric: str,
+ metric: MetricsName,
  y_true: np.ndarray,
  y_pred: np.ndarray,
  task_type: str,
@@ -466,30 +452,32 @@ def compute_single_metric(
  ) -> float:
  """Compute a single metric given true and predicted values."""

+ if y_true.size == 0:
+ return 0.0
+
  y_p_binary = (y_pred > 0.5).astype(int)
- metric_lower = metric.lower()
  try:
- if metric_lower.startswith("recall@"):
- k = int(metric_lower.split("@")[1])
+ if metric.startswith("recall@"):
+ k = int(metric.split("@")[1])
  return compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
- if metric_lower.startswith("precision@"):
- k = int(metric_lower.split("@")[1])
+ if metric.startswith("precision@"):
+ k = int(metric.split("@")[1])
  return compute_precision_at_k(y_true, y_pred, user_ids, k) # type: ignore
- if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
- k_str = metric_lower.split("@")[1]
+ if metric.startswith("hitrate@") or metric.startswith("hr@"):
+ k_str = metric.split("@")[1]
  k = int(k_str)
  return compute_hitrate_at_k(y_true, y_pred, user_ids, k) # type: ignore
- if metric_lower.startswith("mrr@"):
- k = int(metric_lower.split("@")[1])
+ if metric.startswith("mrr@"):
+ k = int(metric.split("@")[1])
  return compute_mrr_at_k(y_true, y_pred, user_ids, k) # type: ignore
- if metric_lower.startswith("ndcg@"):
- k = int(metric_lower.split("@")[1])
+ if metric.startswith("ndcg@"):
+ k = int(metric.split("@")[1])
  return compute_ndcg_at_k(y_true, y_pred, user_ids, k) # type: ignore
- if metric_lower.startswith("map@"):
- k = int(metric_lower.split("@")[1])
+ if metric.startswith("map@"):
+ k = int(metric.split("@")[1])
  return compute_map_at_k(y_true, y_pred, user_ids, k) # type: ignore
  # cosine for matching task
- if metric_lower == "cosine":
+ if metric == "cosine":
  return compute_cosine_separation(y_true, y_pred)
  if metric == "auc":
  value = float(
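
The new guard at the top means empty inputs now return 0.0 instead of raising inside sklearn. A quick check (import path assumed; the trailing None is the user_ids argument):

import numpy as np
from nextrec.basic.metrics import compute_single_metric

empty = np.array([])
assert compute_single_metric("auc", empty, empty, "binary", None) == 0.0
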
@@ -570,15 +558,31 @@ def compute_single_metric(
  def evaluate_metrics(
  y_true: np.ndarray | None,
  y_pred: np.ndarray | None,
- metrics: list[str], # example: ['auc', 'logloss']
- task: str | list[str], # example: 'binary' or ['binary', 'regression']
- target_names: list[str], # example: ['target1', 'target2']
- task_specific_metrics: (
- dict[str, list[str]] | None
- ) = None, # example: {'target1': ['auc', 'logloss'], 'target2': ['mse']}
- user_ids: np.ndarray | None = None, # example: User IDs for GAUC computation
- ) -> dict: # {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
- """Evaluate specified metrics for given true and predicted values."""
+ metrics: list[MetricsName],
+ task: TaskTypeName | list[TaskTypeName],
+ target_names: list[str],
+ task_specific_metrics: dict[str, list[MetricsName]] | None = None,
+ user_ids: np.ndarray | None = None,
+ ignore_label: int | float | None = None,
+ ) -> dict:
+ """
+ Evaluate specified metrics for given true and predicted values.
+ Supports single-task and multi-task evaluation.
+ Handles optional ignore_label to exclude certain samples.
+
+ Args:
+ y_true: Ground truth labels.
+ y_pred: Predicted values.
+ metrics: List of metric names to compute.
+ task: Task type(s) - 'binary', 'regression', etc.
+ target_names: Names of target variables. e.g., ['target1', 'target2']
+ task_specific_metrics: Optional dict mapping target names to specific metrics. e.g., {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+ user_ids: Optional user IDs for GAUC and ranking metrics. e.g., User IDs for GAUC computation
+ ignore_label: Optional label value to ignore during evaluation.
+
+ Returns: Dictionary of computed metric values. {'auc': 0.75, 'logloss': 0.45, 'mse_target2': 3.2}
+
+ """

  result = {}
  if y_true is None or y_pred is None:
@@ -588,70 +592,81 @@ def evaluate_metrics(
  nums_task = len(task) if isinstance(task, list) else 1
  # Single task evaluation
  if nums_task == 1:
+ if ignore_label is not None:
+ valid_mask = y_true != ignore_label
+ if np.any(valid_mask):
+ y_true = y_true[valid_mask]
+ y_pred = y_pred[valid_mask]
+ if user_ids is not None:
+ user_ids = user_ids[valid_mask]
+ else:
+ return result
  for metric in metrics:
- metric_lower = metric.lower()
  value = compute_single_metric(
- metric_lower, y_true, y_pred, primary_task, user_ids
+ metric, y_true, y_pred, primary_task, user_ids
  )
- result[metric_lower] = value
+ result[metric] = value
  # Multi-task evaluation
  else:
- for metric in metrics:
- metric_lower = metric.lower()
- for task_idx in range(nums_task):
- # Check if metric should be computed for given task
- should_compute = True
- if task_specific_metrics is not None and task_idx < len(target_names):
- task_name = target_names[task_idx]
- should_compute = metric_lower in task_specific_metrics.get(
- task_name, []
- )
- else:
- # Get task type for specific index
- if isinstance(task, list) and task_idx < len(task):
- task_type = task[task_idx]
- elif isinstance(task, str):
- task_type = task
- else:
- task_type = "binary"
- if task_type in ["binary", "multilabel"]:
- should_compute = metric_lower in {
- "auc",
- "gauc",
- "ks",
- "logloss",
- "accuracy",
- "acc",
- "precision",
- "recall",
- "f1",
- "micro_f1",
- "macro_f1",
- }
- elif task_type == "regression":
- should_compute = metric_lower in {
- "mse",
- "mae",
- "rmse",
- "r2",
- "mape",
- "msle",
- }
- if not should_compute:
+ task_types = []
+ for task_idx in range(nums_task):
+ if isinstance(task, list) and task_idx < len(task):
+ task_types.append(task[task_idx])
+ elif isinstance(task, str):
+ task_types.append(task)
+ else:
+ task_types.append("binary")
+ metric_allowlist = {
+ "binary": {
+ "auc",
+ "gauc",
+ "ks",
+ "logloss",
+ "accuracy",
+ "acc",
+ "precision",
+ "recall",
+ "f1",
+ "micro_f1",
+ "macro_f1",
+ },
+ "regression": {
+ "mse",
+ "mae",
+ "rmse",
+ "r2",
+ "mape",
+ "msle",
+ },
+ }
+ for task_idx in range(nums_task):
+ task_type = task_types[task_idx]
+ target_name = target_names[task_idx]
+ if task_specific_metrics is not None and task_idx < len(target_names):
+ allowed_metrics = {
+ m for m in task_specific_metrics.get(target_name, [])
+ }
+ else:
+ allowed_metrics = metric_allowlist.get(task_type)
+ for metric in metrics:
+ if allowed_metrics is not None and metric not in allowed_metrics:
  continue
- target_name = target_names[task_idx]
- # Get task type for specific index
- if isinstance(task, list) and task_idx < len(task):
- task_type = task[task_idx]
- elif isinstance(task, str):
- task_type = task
- else:
- task_type = "binary"
  y_true_task = y_true[:, task_idx]
  y_pred_task = y_pred[:, task_idx]
+ task_user_ids = user_ids
+ if ignore_label is not None:
+ valid_mask = y_true_task != ignore_label
+ if np.any(valid_mask):
+ y_true_task = y_true_task[valid_mask]
+ y_pred_task = y_pred_task[valid_mask]
+ if task_user_ids is not None:
+ task_user_ids = task_user_ids[valid_mask]
+ else:
+ result[f"{metric}_{target_name}"] = 0.0
+ continue
  # Compute metric
  value = compute_single_metric(
- metric_lower, y_true_task, y_pred_task, task_type, user_ids
+ metric, y_true_task, y_pred_task, task_type, task_user_ids
  )
- result[f"{metric_lower}_{target_name}"] = value
+ result[f"{metric}_{target_name}"] = value
  return result
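
A sketch of the new ignore_label behavior (import path assumed): rows whose label equals ignore_label are dropped before metrics are computed; in the multi-task branch the filtering happens per task, and a task left with no valid rows reports 0.0.

import numpy as np
from nextrec.basic.metrics import evaluate_metrics

y_true = np.array([1.0, 0.0, -1.0, 1.0])  # -1 marks "no label" here
y_pred = np.array([0.9, 0.2, 0.5, 0.7])
result = evaluate_metrics(
    y_true,
    y_pred,
    metrics=["auc", "logloss"],
    task="binary",
    target_names=["click"],
    ignore_label=-1,
)
# metrics are computed over the three rows whose label is not -1
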
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 29/12/2025
+ Checkpoint: edit on 30/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -646,6 +646,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  sampler=sampler,
  collate_fn=collate_fn,
  num_workers=num_workers,
+ pin_memory=self.device.type == "cuda",
+ persistent_workers=num_workers > 0,
  )
  return (loader, dataset) if return_dataset else loader

@@ -1119,16 +1121,17 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  train_log_payload, step=epoch + 1, split="train"
  )
  if valid_loader is not None:
- self.callbacks.on_validation_begin()
- val_metrics = self.evaluate(
- valid_loader,
- user_ids=valid_user_ids if self.needs_user_ids else None,
- num_workers=num_workers,
- )
- should_log_valid = (epoch + 1) % log_interval == 0 or (
+ should_eval_valid = (epoch + 1) % log_interval == 0 or (
  epoch + 1
  ) == epochs
- if should_log_valid:
+ val_metrics = None
+ if should_eval_valid:
+ self.callbacks.on_validation_begin()
+ val_metrics = self.evaluate(
+ valid_loader,
+ user_ids=valid_user_ids if self.needs_user_ids else None,
+ num_workers=num_workers,
+ )
  display_metrics_table(
  epoch=epoch + 1,
  epochs=epochs,
@@ -1142,23 +1145,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  is_main_process=self.is_main_process,
  colorize=lambda s: colorize(" " + s, color="cyan"),
  )
- self.callbacks.on_validation_end()
- if should_log_valid and val_metrics and self.training_logger:
- self.training_logger.log_metrics(
- val_metrics, step=epoch + 1, split="valid"
- )
+ self.callbacks.on_validation_end()
+ if val_metrics and self.training_logger:
+ self.training_logger.log_metrics(
+ val_metrics, step=epoch + 1, split="valid"
+ )
  if not val_metrics:
- if self.is_main_process:
+ if should_eval_valid and self.is_main_process:
  logging.info(
  colorize(
  "Warning: No validation metrics computed. Skipping validation for this epoch.",
  color="yellow",
  )
  )
- continue
- epoch_logs = {**train_log_payload}
- for k, v in val_metrics.items():
- epoch_logs[f"val_{k}"] = v
+ epoch_logs = {**train_log_payload}
+ else:
+ epoch_logs = {**train_log_payload}
+ for k, v in val_metrics.items():
+ epoch_logs[f"val_{k}"] = v
  else:
  epoch_logs = {**train_log_payload}
  if self.is_main_process:
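
Net effect of these two hunks: evaluate() now runs only on epochs that will actually be logged, rather than every epoch with the result sometimes discarded. The schedule is exactly the condition shown above; a pure-Python illustration:

epochs, log_interval = 10, 3
eval_epochs = [e + 1 for e in range(epochs)
               if (e + 1) % log_interval == 0 or (e + 1) == epochs]
assert eval_epochs == [3, 6, 9, 10]
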
@@ -1340,6 +1344,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  target_names=self.target_columns,
  task_specific_metrics=self.task_specific_metrics,
  user_ids=combined_user_ids,
+ ignore_label=self.ignore_label,
  )
  return avg_loss, metrics_dict
  return avg_loss
@@ -1387,6 +1392,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  sampler=valid_sampler,
  collate_fn=collate_fn,
  num_workers=num_workers,
+ pin_memory=self.device.type == "cuda",
+ persistent_workers=num_workers > 0,
  )
  valid_user_ids = None
  if needs_user_ids:
@@ -1532,6 +1539,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  target_names=self.target_columns,
  task_specific_metrics=self.task_specific_metrics,
  user_ids=final_user_ids,
+ ignore_label=self.ignore_label,
  )
  return metrics_dict

@@ -282,6 +282,8 @@ class RecDataLoader(FeatureSet):
  sampler=sampler,
  collate_fn=collate_fn,
  num_workers=num_workers,
+ pin_memory=torch.cuda.is_available(),
+ persistent_workers=num_workers > 0,
  )

  def create_from_path(
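
For context on the two flags added across these loaders (plain PyTorch, not nextrec-specific): pin_memory stages batches in page-locked host memory to speed host-to-GPU copies, and persistent_workers keeps worker processes alive between epochs; it must be False when num_workers == 0, hence the guard in the diff. A minimal standalone sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(128, 8), torch.randint(0, 2, (128,)))
num_workers = 2
loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # no-op on CPU-only machines
    persistent_workers=num_workers > 0,    # keep workers alive across epochs
)
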