nextrec 0.4.6__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on 05/12/2025
+Checkpoint: edit on 18/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -25,7 +25,13 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from nextrec.basic.callback import EarlyStopper
+from nextrec.basic.callback import (
+    EarlyStopper,
+    CallbackList,
+    Callback,
+    CheckpointSaver,
+    LearningRateScheduler,
+)
 from nextrec.basic.features import (
     DenseFeature,
     SparseFeature,
@@ -42,7 +48,14 @@ from nextrec.data.dataloader import build_tensors_from_data
 from nextrec.data.batch_utils import collate_fn, batch_to_dict
 from nextrec.data.data_processing import get_column_data, get_user_ids
 
-from nextrec.loss import get_loss_fn, get_loss_kwargs
+from nextrec.loss import (
+    BPRLoss,
+    HingeLoss,
+    InfoNCELoss,
+    SampledSoftmaxLoss,
+    TripletLoss,
+    get_loss_fn,
+)
 from nextrec.utils.tensor import to_tensor
 from nextrec.utils.device import configure_device
 from nextrec.utils.optimizer import get_optimizer, get_scheduler
@@ -71,13 +84,14 @@ class BaseModel(FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
-        device: str = "cpu",
-        early_stop_patience: int = 20,
-        session_id: str | None = None,
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
+        device: str = "cpu",
+        early_stop_patience: int = 20,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -91,16 +105,20 @@ class BaseModel(FeatureSet, nn.Module):
             dense_features: DenseFeature definitions.
             sparse_features: SparseFeature definitions.
             sequence_features: SequenceFeature definitions.
-            target: Target column name.
-            id_columns: Identifier column name, only need to specify if GAUC is required.
+            target: Target column name. e.g., 'label' or ['label1', 'label2'].
+            id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
             task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-            device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
+
             embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
             dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
             embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
             dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
+
+            device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
             early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-            session_id: Session id for logging. If None, a default id with timestamps will be created.
+            session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
+            callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
+
             distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
             rank: Global rank (defaults to env RANK).
             world_size: Number of processes (defaults to env WORLD_SIZE).
@@ -151,6 +169,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
+        self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
 
     def register_regularization_weights(
         self,
@@ -162,8 +181,24 @@ class BaseModel(FeatureSet, nn.Module):
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
         embed_dict = getattr(embedding_layer, "embed_dict", None)
+        embedding_params: list[torch.Tensor] = []
         if embed_dict is not None:
-            self.embedding_params.extend(embed.weight for embed in embed_dict.values())
+            embedding_params.extend(
+                embed.weight
+                for embed in embed_dict.values()
+                if hasattr(embed, "weight")
+            )
+        else:
+            weight = getattr(embedding_layer, "weight", None)
+            if isinstance(weight, torch.Tensor):
+                embedding_params.append(weight)
+
+        existing_embedding_ids = {id(param) for param in self.embedding_params}
+        for param in embedding_params:
+            if id(param) not in existing_embedding_ids:
+                self.embedding_params.append(param)
+                existing_embedding_ids.add(id(param))
+
         skip_types = (
             nn.BatchNorm1d,
             nn.BatchNorm2d,
@@ -172,6 +207,7 @@ class BaseModel(FeatureSet, nn.Module):
             nn.Dropout2d,
             nn.Dropout3d,
         )
+        existing_reg_ids = {id(param) for param in self.regularization_weights}
         for name, module in self.named_modules():
             if (
                 module is self
@@ -182,7 +218,9 @@ class BaseModel(FeatureSet, nn.Module):
             ):
                 continue
             if isinstance(module, nn.Linear):
-                self.regularization_weights.append(module.weight)
+                if id(module.weight) not in existing_reg_ids:
+                    self.regularization_weights.append(module.weight)
+                    existing_reg_ids.add(id(module.weight))
 
     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
@@ -335,6 +373,7 @@ class BaseModel(FeatureSet, nn.Module):
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
         Configure the model for training.
@@ -346,6 +385,7 @@ class BaseModel(FeatureSet, nn.Module):
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+            callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         if loss_params is None:
             self.loss_params = {}
@@ -427,6 +467,11 @@ class BaseModel(FeatureSet, nn.Module):
             )
         self.loss_weights = weights
 
+        # Add callbacks from compile if provided
+        if callbacks:
+            for callback in callbacks:
+                self.callbacks.append(callback)
+
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError(
@@ -580,6 +625,53 @@ class BaseModel(FeatureSet, nn.Module):
                 task=self.task, metrics=metrics, target_names=self.target_columns
             )
         ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+
+        # Setup default callbacks if none exist
+        if len(self.callbacks.callbacks) == 0:
+            if self.nums_task == 1:
+                monitor_metric = f"val_{self.metrics[0]}"
+            else:
+                monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
+
+            if self.early_stop_patience > 0:
+                self.callbacks.append(
+                    EarlyStopper(
+                        monitor=monitor_metric,
+                        patience=self.early_stop_patience,
+                        mode=self.best_metrics_mode,
+                        restore_best_weights=not self.distributed,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+            if self.is_main_process:
+                self.callbacks.append(
+                    CheckpointSaver(
+                        save_path=self.best_path,
+                        monitor=monitor_metric,
+                        mode=self.best_metrics_mode,
+                        save_best_only=True,
+                        verbose=1,
+                    )
+                )
+
+            if self.scheduler_fn is not None:
+                self.callbacks.append(
+                    LearningRateScheduler(
+                        scheduler=self.scheduler_fn,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+        self.callbacks.set_model(self)
+        self.callbacks.set_params(
+            {
+                "epochs": epochs,
+                "batch_size": batch_size,
+                "metrics": self.metrics,
+            }
+        )
+
         self.early_stopper = EarlyStopper(
             patience=self.early_stop_patience, mode=self.best_metrics_mode
         )
@@ -648,7 +740,17 @@ class BaseModel(FeatureSet, nn.Module):
             else:
                 train_loader = train_data
         else:
-            loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
+            result = self.prepare_data_loader(
+                train_data,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                return_dataset=True,
+            )
+            assert isinstance(
+                result, tuple
+            ), "Expected tuple from prepare_data_loader with return_dataset=True"
+            loader, dataset = result
             if (
                 auto_distributed_sampler
                 and self.distributed
@@ -729,8 +831,13 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info("")
             logging.info(colorize(f"Model device: {self.device}", bold=True))
 
+        self.callbacks.on_train_begin()
+
         for epoch in range(epochs):
             self.epoch_index = epoch
+
+            self.callbacks.on_epoch_begin(epoch)
+
             if is_streaming and self.is_main_process:
                 logging.info("")
                 logging.info(
@@ -740,10 +847,16 @@ class BaseModel(FeatureSet, nn.Module):
             # handle train result
             if (
                 self.distributed
+                and isinstance(train_loader, DataLoader)
                 and hasattr(train_loader, "sampler")
                 and isinstance(train_loader.sampler, DistributedSampler)
             ):
                 train_loader.sampler.set_epoch(epoch)
+            # Type guard: ensure train_loader is DataLoader for train_epoch
+            if not isinstance(train_loader, DataLoader):
+                raise TypeError(
+                    f"Expected DataLoader for training, got {type(train_loader)}"
+                )
             train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
             if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
                 train_loss, train_metrics = train_result
@@ -803,6 +916,9 @@ class BaseModel(FeatureSet, nn.Module):
                     train_log_payload, step=epoch + 1, split="train"
                 )
             if valid_loader is not None:
+                # Call on_validation_begin
+                self.callbacks.on_validation_begin()
+
                 # pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(
                     valid_loader,
@@ -849,17 +965,17 @@ class BaseModel(FeatureSet, nn.Module):
                             color="cyan",
                         )
                     )
+
+                # Call on_validation_end
+                self.callbacks.on_validation_end()
                 if val_metrics and self.training_logger:
                     self.training_logger.log_metrics(
                         val_metrics, step=epoch + 1, split="valid"
                     )
+
                 # Handle empty validation metrics
                 if not val_metrics:
                     if self.is_main_process:
-                        self.save_model(
-                            self.checkpoint_path, add_timestamp=False, verbose=False
-                        )
-                        self.best_checkpoint_path = self.checkpoint_path
                         logging.info(
                             colorize(
                                 "Warning: No validation metrics computed. Skipping validation for this epoch.",
@@ -867,81 +983,26 @@ class BaseModel(FeatureSet, nn.Module):
                             )
                         )
                     continue
-                if self.nums_task == 1:
-                    primary_metric_key = self.metrics[0]
-                else:
-                    primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
-                primary_metric = val_metrics.get(
-                    primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
-                ) # get primary metric value, default to first metric if not found
-
-                # In distributed mode, broadcast primary_metric to ensure all processes use the same value
-                if self.distributed and dist.is_available() and dist.is_initialized():
-                    metric_tensor = torch.tensor(
-                        [primary_metric], device=self.device, dtype=torch.float32
-                    )
-                    dist.broadcast(metric_tensor, src=0)
-                    primary_metric = float(metric_tensor.item())
-
-                improved = False
-                # early stopping check
-                if self.best_metrics_mode == "max":
-                    if primary_metric > self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
-                else:
-                    if primary_metric < self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
 
-                # save checkpoint and best model for main process
-                if self.is_main_process:
-                    self.save_model(
-                        self.checkpoint_path, add_timestamp=False, verbose=False
-                    )
-                    logging.info(" ")
-                    if improved:
-                        logging.info(
-                            colorize(
-                                f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
-                            )
-                        )
-                        self.save_model(
-                            self.best_path, add_timestamp=False, verbose=False
-                        )
-                        self.best_checkpoint_path = self.best_path
-                        self.early_stopper.trial_counter = 0
-                    else:
-                        self.early_stopper.trial_counter += 1
-                        logging.info(
-                            colorize(
-                                f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
-                            )
-                        )
-                        if self.early_stopper.trial_counter >= self.early_stopper.patience:
-                            self.stop_training = True
-                            logging.info(
-                                colorize(
-                                    f"Early stopping triggered after {epoch + 1} epochs",
-                                    color="bright_red",
-                                    bold=True,
-                                )
-                            )
-                else:
-                    # Non-main processes also update trial_counter to keep in sync
-                    if improved:
-                        self.early_stopper.trial_counter = 0
-                    else:
-                        self.early_stopper.trial_counter += 1
+                # Prepare epoch logs for callbacks
+                epoch_logs = {**train_log_payload}
+                if val_metrics:
+                    # Add val_ prefix to validation metrics
+                    for k, v in val_metrics.items():
+                        epoch_logs[f"val_{k}"] = v
             else:
+                # No validation data
+                epoch_logs = {**train_log_payload}
                 if self.is_main_process:
                     self.save_model(
                         self.checkpoint_path, add_timestamp=False, verbose=False
                     )
-                    self.save_model(self.best_path, add_timestamp=False, verbose=False)
-                    self.best_checkpoint_path = self.best_path
+                    self.best_checkpoint_path = self.checkpoint_path
 
-            # Broadcast stop_training flag to all processes (always, regardless of validation)
+            # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
+            self.callbacks.on_epoch_end(epoch, epoch_logs)
+
+            # Broadcast stop_training flag to all processes
             if self.distributed and dist.is_available() and dist.is_initialized():
                 stop_tensor = torch.tensor(
                     [int(self.stop_training)], device=self.device
@@ -951,14 +1012,9 @@ class BaseModel(FeatureSet, nn.Module):
 
             if self.stop_training:
                 break
-            if self.scheduler_fn is not None:
-                if isinstance(
-                    self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
-                ):
-                    if valid_loader is not None:
-                        self.scheduler_fn.step(primary_metric)
-                else:
-                    self.scheduler_fn.step()
+        # Call on_train_end for all callbacks
+        self.callbacks.on_train_end()
+
         if self.distributed and dist.is_available() and dist.is_initialized():
             dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
         if self.is_main_process:
@@ -970,9 +1026,17 @@ class BaseModel(FeatureSet, nn.Module):
                 logging.info(
                     colorize(f"Load best model from: {self.best_checkpoint_path}")
                 )
-            self.load_model(
-                self.best_checkpoint_path, map_location=self.device, verbose=False
-            )
+            if os.path.exists(self.best_checkpoint_path):
+                self.load_model(
+                    self.best_checkpoint_path, map_location=self.device, verbose=False
+                )
+            elif self.is_main_process:
+                logging.info(
+                    colorize(
+                        f"Warning: Best checkpoint not found at {self.best_checkpoint_path}, skip loading best model.",
+                        color="yellow",
+                    )
+                )
         if self.training_logger:
             self.training_logger.close()
         return self
@@ -1847,7 +1911,9 @@ class BaseModel(FeatureSet, nn.Module):
 class BaseMatchModel(BaseModel):
     """
     Base class for match (retrieval/recall) models
-    Supports pointwise, pairwise, and listwise training modes
+
+    - Pointwise: predicts a user-item match score/probability using labels (default target: 'label')
+    - Pairwise/Listwise: trains with in-batch negatives; labels can be omitted by setting target=None
     """
 
     @property
@@ -1887,6 +1953,16 @@ class BaseMatchModel(BaseModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
+        target: list[str] | str | None = "label",
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
+        distributed: bool = False,
+        rank: int | None = None,
+        world_size: int | None = None,
+        local_rank: int | None = None,
+        ddp_find_unused_parameters: bool = False,
         **kwargs,
     ):
 
@@ -1911,14 +1987,22 @@ class BaseMatchModel(BaseModel):
             dense_features=all_dense_features,
             sparse_features=all_sparse_features,
             sequence_features=all_sequence_features,
-            target=["label"],
-            task="binary",
+            target=target,
+            id_columns=id_columns,
+            task=task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
+            session_id=session_id,
+            callbacks=callbacks,
+            distributed=distributed,
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            ddp_find_unused_parameters=ddp_find_unused_parameters,
             **kwargs,
         )
 
@@ -1989,73 +2073,76 @@ class BaseMatchModel(BaseModel):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
-        Compile match model with optimizer, scheduler, and loss function.
-        Mirrors BaseModel.compile while adding training_mode validation for match tasks.
+        Configure the match model for training.
+
+        This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
                 f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )
-        # Call parent compile with match-specific logic
-        optimizer_params = optimizer_params or {}
-
-        self.optimizer_name = (
-            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
-        )
-        self.optimizer_params = optimizer_params
-        if isinstance(scheduler, str):
-            self.scheduler_name = scheduler
-        elif scheduler is not None:
-            # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.scheduler_name = getattr(
-                scheduler,
-                "__name__",
-                getattr(scheduler.__class__, "__name__", str(scheduler)),
-            )
-        else:
-            self.scheduler_name = None
-        self.scheduler_params = scheduler_params or {}
-        self.loss_config = loss
-        self.loss_params = loss_params or {}
 
-        self.optimizer_fn = get_optimizer(
-            optimizer=optimizer, params=self.parameters(), **optimizer_params
-        )
-        # Set loss function based on training mode
-        default_losses = {
+        default_loss_by_mode: dict[str, str] = {
             "pointwise": "bce",
             "pairwise": "bpr",
             "listwise": "sampled_softmax",
         }
 
-        if loss is None:
-            loss_value = default_losses.get(self.training_mode, "bce")
-        elif isinstance(loss, list):
-            loss_value = (
-                loss[0]
-                if loss and loss[0] is not None
-                else default_losses.get(self.training_mode, "bce")
-            )
-        else:
-            loss_value = loss
-
-        # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
-        if self.training_mode in {"pairwise", "listwise"} and loss_value in {
-            "bce",
-            "binary_crossentropy",
-        }:
-            loss_value = default_losses.get(self.training_mode, loss_value)
-        loss_kwargs = get_loss_kwargs(self.loss_params, 0)
-        self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
-        # set scheduler
-        self.scheduler_fn = (
-            get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {}))
-            if scheduler
-            else None
+        effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+        if effective_loss is None:
+            effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, (str,)):
+            if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
+                "bce",
+                "binary_crossentropy",
+            }:
+                effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, list):
+            if not effective_loss:
+                effective_loss = [default_loss_by_mode[self.training_mode]]
+            else:
+                first = effective_loss[0]
+                if (
+                    self.training_mode in {"pairwise", "listwise"}
+                    and isinstance(first, str)
+                    and first in {"bce", "binary_crossentropy"}
+                ):
+                    effective_loss = [
+                        default_loss_by_mode[self.training_mode],
+                        *effective_loss[1:],
+                    ]
+        return super().compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=effective_loss,
+            loss_params=loss_params,
+            loss_weights=loss_weights,
+            callbacks=callbacks,
         )
 
+    def inbatch_logits(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
+        if self.similarity_metric == "dot":
+            logits = torch.matmul(user_emb, item_emb.t())
+        elif self.similarity_metric == "cosine":
+            user_norm = F.normalize(user_emb, p=2, dim=-1)
+            item_norm = F.normalize(item_emb, p=2, dim=-1)
+            logits = torch.matmul(user_norm, item_norm.t())
+        elif self.similarity_metric == "euclidean":
+            user_sq = (user_emb**2).sum(dim=1, keepdim=True) # [B, 1]
+            item_sq = (item_emb**2).sum(dim=1, keepdim=True).t() # [1, B]
+            logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
+        else:
+            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
+        return logits / self.temperature
+
     def compute_similarity(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
@@ -2125,9 +2212,7 @@ class BaseMatchModel(BaseModel):
 
     def compute_loss(self, y_pred, y_true):
         if self.training_mode == "pointwise":
-            if y_true is None:
-                return torch.tensor(0.0, device=self.device)
-            return self.loss_fn[0](y_pred, y_true)
+            return super().compute_loss(y_pred, y_true)
 
         # pairwise / listwise using inbatch neg
         elif self.training_mode in ["pairwise", "listwise"]:
@@ -2136,14 +2221,39 @@ class BaseMatchModel(BaseModel):
                     "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
                 )
             user_emb, item_emb = y_pred # [B, D], [B, D]
-            logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
-            logits = logits / self.temperature
-            batch_size = logits.size(0)
-            targets = torch.arange(
-                batch_size, device=logits.device
-            ) # [0, 1, 2, ..., B-1]
-            # Cross-Entropy = InfoNCE
-            loss = F.cross_entropy(logits, targets)
+            batch_size = user_emb.size(0)
+            if batch_size < 2:
+                return torch.tensor(0.0, device=user_emb.device)
+
+            logits = self.inbatch_logits(user_emb, item_emb) # [B, B]
+
+            eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
+            pos_logits = logits.diag() # [B]
+            neg_logits = logits.masked_select(~eye).view(
+                batch_size, batch_size - 1
+            ) # [B, B-1]
+
+            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+            if isinstance(loss_fn, SampledSoftmaxLoss):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, (BPRLoss, HingeLoss)):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, TripletLoss):
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            elif isinstance(loss_fn, InfoNCELoss) and self.similarity_metric == "dot":
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            else:
+                targets = torch.arange(batch_size, device=logits.device)
+                loss = F.cross_entropy(logits, targets)
+
+            if self.loss_weights is not None:
+                loss *= float(self.loss_weights[0])
             return loss
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
@@ -2154,17 +2264,23 @@ class BaseMatchModel(BaseModel):
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
-
-        feature_data = {}
-        for feature in features:
-            if isinstance(data, dict):
-                if feature.name in data:
-                    feature_data[feature.name] = data[feature.name]
-            elif isinstance(data, pd.DataFrame):
-                if feature.name in data.columns:
-                    feature_data[feature.name] = data[feature.name].values
-        return self.prepare_data_loader(
-            feature_data, batch_size=batch_size, shuffle=False
+        tensors = build_tensors_from_data(
+            data=data,
+            raw_data=data,
+            features=features,
+            target_columns=[],
+            id_columns=[],
+        )
+        if tensors is None:
+            raise ValueError(
+                "[BaseMatchModel-prepare_feature_data Error] No data available to create DataLoader."
+            )
+        dataset = TensorDictDataset(tensors)
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=collate_fn,
         )
 
     def encode_user(