nextrec 0.4.6__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 05/12/2025
+ Checkpoint: edit on 18/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -25,7 +25,13 @@ from torch.utils.data import DataLoader
  from torch.utils.data.distributed import DistributedSampler
  from torch.nn.parallel import DistributedDataParallel as DDP

- from nextrec.basic.callback import EarlyStopper
+ from nextrec.basic.callback import (
+ EarlyStopper,
+ CallbackList,
+ Callback,
+ CheckpointSaver,
+ LearningRateScheduler,
+ )
  from nextrec.basic.features import (
  DenseFeature,
  SparseFeature,
@@ -42,7 +48,15 @@ from nextrec.data.dataloader import build_tensors_from_data
  from nextrec.data.batch_utils import collate_fn, batch_to_dict
  from nextrec.data.data_processing import get_column_data, get_user_ids

- from nextrec.loss import get_loss_fn, get_loss_kwargs
+ from nextrec.loss import (
+ BPRLoss,
+ HingeLoss,
+ InfoNCELoss,
+ SampledSoftmaxLoss,
+ TripletLoss,
+ get_loss_fn,
+ get_loss_kwargs,
+ )
  from nextrec.utils.tensor import to_tensor
  from nextrec.utils.device import configure_device
  from nextrec.utils.optimizer import get_optimizer, get_scheduler
@@ -71,13 +85,14 @@ class BaseModel(FeatureSet, nn.Module):
  target: list[str] | str | None = None,
  id_columns: list[str] | str | None = None,
  task: str | list[str] | None = None,
- device: str = "cpu",
- early_stop_patience: int = 20,
- session_id: str | None = None,
  embedding_l1_reg: float = 0.0,
  dense_l1_reg: float = 0.0,
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
+ device: str = "cpu",
+ early_stop_patience: int = 20,
+ session_id: str | None = None,
+ callbacks: list[Callback] | None = None,
  distributed: bool = False,
  rank: int | None = None,
  world_size: int | None = None,
@@ -91,16 +106,20 @@ class BaseModel(FeatureSet, nn.Module):
  dense_features: DenseFeature definitions.
  sparse_features: SparseFeature definitions.
  sequence_features: SequenceFeature definitions.
- target: Target column name.
- id_columns: Identifier column name, only need to specify if GAUC is required.
+ target: Target column name. e.g., 'label' or ['label1', 'label2'].
+ id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
  task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
- device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
+
  embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
  dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
  embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
  dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
+
+ device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
  early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
- session_id: Session id for logging. If None, a default id with timestamps will be created.
+ session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
+ callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
+
  distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
  rank: Global rank (defaults to env RANK).
  world_size: Number of processes (defaults to env WORLD_SIZE).
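For orientation, a minimal construction sketch using the reordered signature and the new callbacks argument. The subclass name and feature lists are hypothetical; the EarlyStopper/CheckpointSaver keyword arguments mirror the ones visible later in this diff. When callbacks are supplied this way, fit() skips installing its defaults (see the `len(self.callbacks.callbacks) == 0` guard below).

    from nextrec.basic.callback import EarlyStopper, CheckpointSaver

    # MyRanker is a stand-in for any concrete BaseModel subclass.
    model = MyRanker(
        dense_features=dense_features,
        sparse_features=sparse_features,
        target="label",
        task="binary",
        device="cuda:0",
        callbacks=[
            EarlyStopper(monitor="val_auc", patience=5, mode="max"),
            CheckpointSaver(save_path="./ckpt/best.pt", monitor="val_auc",
                            mode="max", save_best_only=True),
        ],
    )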
@@ -151,6 +170,7 @@ class BaseModel(FeatureSet, nn.Module):
  self.max_gradient_norm = 1.0
  self.logger_initialized = False
  self.training_logger = None
+ self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()

  def register_regularization_weights(
  self,
@@ -162,8 +182,22 @@ class BaseModel(FeatureSet, nn.Module):
  include_modules = include_modules or []
  embedding_layer = getattr(self, embedding_attr, None)
  embed_dict = getattr(embedding_layer, "embed_dict", None)
+ embedding_params: list[torch.Tensor] = []
  if embed_dict is not None:
- self.embedding_params.extend(embed.weight for embed in embed_dict.values())
+ embedding_params.extend(
+ embed.weight for embed in embed_dict.values() if hasattr(embed, "weight")
+ )
+ else:
+ weight = getattr(embedding_layer, "weight", None)
+ if isinstance(weight, torch.Tensor):
+ embedding_params.append(weight)
+
+ existing_embedding_ids = {id(param) for param in self.embedding_params}
+ for param in embedding_params:
+ if id(param) not in existing_embedding_ids:
+ self.embedding_params.append(param)
+ existing_embedding_ids.add(id(param))
+
  skip_types = (
  nn.BatchNorm1d,
  nn.BatchNorm2d,
@@ -172,6 +206,7 @@ class BaseModel(FeatureSet, nn.Module):
  nn.Dropout2d,
  nn.Dropout3d,
  )
+ existing_reg_ids = {id(param) for param in self.regularization_weights}
  for name, module in self.named_modules():
  if (
  module is self
@@ -182,7 +217,9 @@ class BaseModel(FeatureSet, nn.Module):
  ):
  continue
  if isinstance(module, nn.Linear):
- self.regularization_weights.append(module.weight)
+ if id(module.weight) not in existing_reg_ids:
+ self.regularization_weights.append(module.weight)
+ existing_reg_ids.add(id(module.weight))

  def add_reg_loss(self) -> torch.Tensor:
  reg_loss = torch.tensor(0.0, device=self.device)
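The id()-based bookkeeping added above keeps repeated calls to register_regularization_weights from appending the same tensor twice. A standalone illustration of the pattern, in plain torch and independent of nextrec:

    import torch
    import torch.nn as nn

    weights: list[torch.Tensor] = []
    seen = {id(w) for w in weights}
    layer = nn.Linear(4, 2)
    for candidate in (layer.weight, layer.weight):  # second pass is a no-op
        if id(candidate) not in seen:
            weights.append(candidate)
            seen.add(id(candidate))
    assert len(weights) == 1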
@@ -335,6 +372,7 @@ class BaseModel(FeatureSet, nn.Module):
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  loss_weights: int | float | list[int | float] | None = None,
+ callbacks: list[Callback] | None = None,
  ):
  """
  Configure the model for training.
@@ -346,6 +384,7 @@ class BaseModel(FeatureSet, nn.Module):
  loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
  loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
  loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+ callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
  """
  if loss_params is None:
  self.loss_params = {}
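A hedged usage sketch for the extended compile() signature; the optimizer and loss strings are illustrative, and the EarlyStopper kwargs mirror those used elsewhere in this diff:

    model.compile(
        optimizer="adam",
        optimizer_params={"lr": 1e-3},
        loss="bce",
        loss_weights=1.0,
        callbacks=[EarlyStopper(monitor="val_auc", patience=3, mode="max")],
    )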
@@ -427,6 +466,11 @@ class BaseModel(FeatureSet, nn.Module):
  )
  self.loss_weights = weights

+ # Add callbacks from compile if provided
+ if callbacks:
+ for callback in callbacks:
+ self.callbacks.append(callback)
+
  def compute_loss(self, y_pred, y_true):
  if y_true is None:
  raise ValueError(
@@ -580,6 +624,53 @@ class BaseModel(FeatureSet, nn.Module):
  task=self.task, metrics=metrics, target_names=self.target_columns
  )
  ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+
+ # Setup default callbacks if none exist
+ if len(self.callbacks.callbacks) == 0:
+ if self.nums_task == 1:
+ monitor_metric = f"val_{self.metrics[0]}"
+ else:
+ monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
+
+ if self.early_stop_patience > 0:
+ self.callbacks.append(
+ EarlyStopper(
+ monitor=monitor_metric,
+ patience=self.early_stop_patience,
+ mode=self.best_metrics_mode,
+ restore_best_weights=not self.distributed,
+ verbose=1 if self.is_main_process else 0,
+ )
+ )
+
+ if self.is_main_process:
+ self.callbacks.append(
+ CheckpointSaver(
+ save_path=self.best_path,
+ monitor=monitor_metric,
+ mode=self.best_metrics_mode,
+ save_best_only=True,
+ verbose=1,
+ )
+ )
+
+ if self.scheduler_fn is not None:
+ self.callbacks.append(
+ LearningRateScheduler(
+ scheduler=self.scheduler_fn,
+ verbose=1 if self.is_main_process else 0,
+ )
+ )
+
+ self.callbacks.set_model(self)
+ self.callbacks.set_params(
+ {
+ "epochs": epochs,
+ "batch_size": batch_size,
+ "metrics": self.metrics,
+ }
+ )
+
  self.early_stopper = EarlyStopper(
  patience=self.early_stop_patience, mode=self.best_metrics_mode
  )
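When no callbacks were registered, the block above assembles defaults from the existing settings. For a single-process, single-task model with metrics ['auc', 'logloss'], patience 20, and mode 'max', the result is roughly equivalent to the sketch below (the save path is a placeholder):

    model.callbacks.append(EarlyStopper(monitor="val_auc", patience=20, mode="max",
                                        restore_best_weights=True, verbose=1))
    model.callbacks.append(CheckpointSaver(save_path="<session_dir>/best.pt", monitor="val_auc",
                                           mode="max", save_best_only=True, verbose=1))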
@@ -648,7 +739,9 @@ class BaseModel(FeatureSet, nn.Module):
  else:
  train_loader = train_data
  else:
- loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
+ result = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)
+ assert isinstance(result, tuple), "Expected tuple from prepare_data_loader with return_dataset=True"
+ loader, dataset = result
  if (
  auto_distributed_sampler
  and self.distributed
@@ -729,8 +822,13 @@ class BaseModel(FeatureSet, nn.Module):
  logging.info("")
  logging.info(colorize(f"Model device: {self.device}", bold=True))

+ self.callbacks.on_train_begin()
+
  for epoch in range(epochs):
  self.epoch_index = epoch
+
+ self.callbacks.on_epoch_begin(epoch)
+
  if is_streaming and self.is_main_process:
  logging.info("")
  logging.info(
@@ -740,10 +838,14 @@ class BaseModel(FeatureSet, nn.Module):
  # handle train result
  if (
  self.distributed
+ and isinstance(train_loader, DataLoader)
  and hasattr(train_loader, "sampler")
  and isinstance(train_loader.sampler, DistributedSampler)
  ):
  train_loader.sampler.set_epoch(epoch)
+ # Type guard: ensure train_loader is DataLoader for train_epoch
+ if not isinstance(train_loader, DataLoader):
+ raise TypeError(f"Expected DataLoader for training, got {type(train_loader)}")
  train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
  if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
  train_loss, train_metrics = train_result
@@ -803,6 +905,9 @@ class BaseModel(FeatureSet, nn.Module):
  train_log_payload, step=epoch + 1, split="train"
  )
  if valid_loader is not None:
+ # Call on_validation_begin
+ self.callbacks.on_validation_begin()
+
  # pass user_ids only if needed for GAUC metric
  val_metrics = self.evaluate(
  valid_loader,
@@ -849,17 +954,17 @@ class BaseModel(FeatureSet, nn.Module):
  color="cyan",
  )
  )
+
+ # Call on_validation_end
+ self.callbacks.on_validation_end()
  if val_metrics and self.training_logger:
  self.training_logger.log_metrics(
  val_metrics, step=epoch + 1, split="valid"
  )
+
  # Handle empty validation metrics
  if not val_metrics:
  if self.is_main_process:
- self.save_model(
- self.checkpoint_path, add_timestamp=False, verbose=False
- )
- self.best_checkpoint_path = self.checkpoint_path
  logging.info(
  colorize(
  "Warning: No validation metrics computed. Skipping validation for this epoch.",
@@ -867,81 +972,26 @@ class BaseModel(FeatureSet, nn.Module):
  )
  )
  continue
- if self.nums_task == 1:
- primary_metric_key = self.metrics[0]
- else:
- primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
- primary_metric = val_metrics.get(
- primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
- ) # get primary metric value, default to first metric if not found
-
- # In distributed mode, broadcast primary_metric to ensure all processes use the same value
- if self.distributed and dist.is_available() and dist.is_initialized():
- metric_tensor = torch.tensor(
- [primary_metric], device=self.device, dtype=torch.float32
- )
- dist.broadcast(metric_tensor, src=0)
- primary_metric = float(metric_tensor.item())
-
- improved = False
- # early stopping check
- if self.best_metrics_mode == "max":
- if primary_metric > self.best_metric:
- self.best_metric = primary_metric
- improved = True
- else:
- if primary_metric < self.best_metric:
- self.best_metric = primary_metric
- improved = True

- # save checkpoint and best model for main process
- if self.is_main_process:
- self.save_model(
- self.checkpoint_path, add_timestamp=False, verbose=False
- )
- logging.info(" ")
- if improved:
- logging.info(
- colorize(
- f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
- )
- )
- self.save_model(
- self.best_path, add_timestamp=False, verbose=False
- )
- self.best_checkpoint_path = self.best_path
- self.early_stopper.trial_counter = 0
- else:
- self.early_stopper.trial_counter += 1
- logging.info(
- colorize(
- f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
- )
- )
- if self.early_stopper.trial_counter >= self.early_stopper.patience:
- self.stop_training = True
- logging.info(
- colorize(
- f"Early stopping triggered after {epoch + 1} epochs",
- color="bright_red",
- bold=True,
- )
- )
- else:
- # Non-main processes also update trial_counter to keep in sync
- if improved:
- self.early_stopper.trial_counter = 0
- else:
- self.early_stopper.trial_counter += 1
+ # Prepare epoch logs for callbacks
+ epoch_logs = {**train_log_payload}
+ if val_metrics:
+ # Add val_ prefix to validation metrics
+ for k, v in val_metrics.items():
+ epoch_logs[f"val_{k}"] = v
  else:
+ # No validation data
+ epoch_logs = {**train_log_payload}
  if self.is_main_process:
  self.save_model(
  self.checkpoint_path, add_timestamp=False, verbose=False
  )
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
- self.best_checkpoint_path = self.best_path
+ self.best_checkpoint_path = self.checkpoint_path

- # Broadcast stop_training flag to all processes (always, regardless of validation)
+ # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
+ self.callbacks.on_epoch_end(epoch, epoch_logs)
+
+ # Broadcast stop_training flag to all processes
  if self.distributed and dist.is_available() and dist.is_initialized():
  stop_tensor = torch.tensor(
  [int(self.stop_training)], device=self.device
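The logs handed to on_epoch_end() merge the training payload with "val_"-prefixed validation metrics, so a callback monitoring "val_auc" would see a dict shaped roughly like the one below (the train-side keys come from train_log_payload and are assumed here for illustration):

    epoch_logs = {
        "loss": 0.412,       # assumed training key from train_log_payload
        "auc": 0.861,        # assumed training key
        "val_auc": 0.873,    # validation metrics gain the "val_" prefix
        "val_logloss": 0.388,
    }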
@@ -951,14 +1001,9 @@ class BaseModel(FeatureSet, nn.Module):

  if self.stop_training:
  break
- if self.scheduler_fn is not None:
- if isinstance(
- self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
- ):
- if valid_loader is not None:
- self.scheduler_fn.step(primary_metric)
- else:
- self.scheduler_fn.step()
+ # Call on_train_end for all callbacks
+ self.callbacks.on_train_end()
+
  if self.distributed and dist.is_available() and dist.is_initialized():
  dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
  if self.is_main_process:
@@ -970,9 +1015,17 @@ class BaseModel(FeatureSet, nn.Module):
  logging.info(
  colorize(f"Load best model from: {self.best_checkpoint_path}")
  )
- self.load_model(
- self.best_checkpoint_path, map_location=self.device, verbose=False
- )
+ if os.path.exists(self.best_checkpoint_path):
+ self.load_model(
+ self.best_checkpoint_path, map_location=self.device, verbose=False
+ )
+ elif self.is_main_process:
+ logging.info(
+ colorize(
+ f"Warning: Best checkpoint not found at {self.best_checkpoint_path}, skip loading best model.",
+ color="yellow",
+ )
+ )
  if self.training_logger:
  self.training_logger.close()
  return self
@@ -1847,7 +1900,9 @@ class BaseModel(FeatureSet, nn.Module):
  class BaseMatchModel(BaseModel):
  """
  Base class for match (retrieval/recall) models
- Supports pointwise, pairwise, and listwise training modes
+
+ - Pointwise: predicts a user-item match score/probability using labels (default target: 'label')
+ - Pairwise/Listwise: trains with in-batch negatives; labels can be omitted by setting target=None

  """

  @property
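A hedged sketch of the two documented setups, assuming a hypothetical two-tower subclass MyTwoTower with its feature arguments elided; with pairwise or listwise training the label column can be dropped entirely:

    # Pointwise: scored against an explicit label column
    model = MyTwoTower(..., target="label", task="binary")

    # Pairwise/listwise: in-batch negatives, no labels needed
    model = MyTwoTower(..., target=None)
    model.compile(optimizer="adam", loss=None)  # loss=None falls back to 'bpr' / 'sampled_softmax' per mode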
@@ -1887,6 +1942,16 @@ class BaseMatchModel(BaseModel):
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
+ target: list[str] | str | None = "label",
+ id_columns: list[str] | str | None = None,
+ task: str | list[str] | None = None,
+ session_id: str | None = None,
+ callbacks: list[Callback] | None = None,
+ distributed: bool = False,
+ rank: int | None = None,
+ world_size: int | None = None,
+ local_rank: int | None = None,
+ ddp_find_unused_parameters: bool = False,
  **kwargs,
  ):

@@ -1911,14 +1976,22 @@ class BaseMatchModel(BaseModel):
  dense_features=all_dense_features,
  sparse_features=all_sparse_features,
  sequence_features=all_sequence_features,
- target=["label"],
- task="binary",
+ target=target,
+ id_columns=id_columns,
+ task=task,
  device=device,
  embedding_l1_reg=embedding_l1_reg,
  dense_l1_reg=dense_l1_reg,
  embedding_l2_reg=embedding_l2_reg,
  dense_l2_reg=dense_l2_reg,
  early_stop_patience=early_stop_patience,
+ session_id=session_id,
+ callbacks=callbacks,
+ distributed=distributed,
+ rank=rank,
+ world_size=world_size,
+ local_rank=local_rank,
+ ddp_find_unused_parameters=ddp_find_unused_parameters,
  **kwargs,
  )

@@ -1989,73 +2062,74 @@ class BaseMatchModel(BaseModel):
  scheduler_params: dict | None = None,
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
+ loss_weights: int | float | list[int | float] | None = None,
+ callbacks: list[Callback] | None = None,
  ):
  """
- Compile match model with optimizer, scheduler, and loss function.
- Mirrors BaseModel.compile while adding training_mode validation for match tasks.
+ Configure the match model for training.
+
+ This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
  """
  if self.training_mode not in self.support_training_modes:
  raise ValueError(
  f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
  )
- # Call parent compile with match-specific logic
- optimizer_params = optimizer_params or {}

- self.optimizer_name = (
- optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
- )
- self.optimizer_params = optimizer_params
- if isinstance(scheduler, str):
- self.scheduler_name = scheduler
- elif scheduler is not None:
- # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
- self.scheduler_name = getattr(
- scheduler,
- "__name__",
- getattr(scheduler.__class__, "__name__", str(scheduler)),
- )
- else:
- self.scheduler_name = None
- self.scheduler_params = scheduler_params or {}
- self.loss_config = loss
- self.loss_params = loss_params or {}
-
- self.optimizer_fn = get_optimizer(
- optimizer=optimizer, params=self.parameters(), **optimizer_params
- )
- # Set loss function based on training mode
- default_losses = {
+ default_loss_by_mode: dict[str, str] = {
  "pointwise": "bce",
  "pairwise": "bpr",
  "listwise": "sampled_softmax",
  }

- if loss is None:
- loss_value = default_losses.get(self.training_mode, "bce")
- elif isinstance(loss, list):
- loss_value = (
- loss[0]
- if loss and loss[0] is not None
- else default_losses.get(self.training_mode, "bce")
- )
- else:
- loss_value = loss
-
- # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
- if self.training_mode in {"pairwise", "listwise"} and loss_value in {
- "bce",
- "binary_crossentropy",
- }:
- loss_value = default_losses.get(self.training_mode, loss_value)
- loss_kwargs = get_loss_kwargs(self.loss_params, 0)
- self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
- # set scheduler
- self.scheduler_fn = (
- get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {}))
- if scheduler
- else None
+ effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+ if effective_loss is None:
+ effective_loss = default_loss_by_mode[self.training_mode]
+ elif isinstance(effective_loss, (str,)):
+ if (
+ self.training_mode in {"pairwise", "listwise"}
+ and effective_loss in {"bce", "binary_crossentropy"}
+ ):
+ effective_loss = default_loss_by_mode[self.training_mode]
+ elif isinstance(effective_loss, list):
+ if not effective_loss:
+ effective_loss = [default_loss_by_mode[self.training_mode]]
+ else:
+ first = effective_loss[0]
+ if (
+ self.training_mode in {"pairwise", "listwise"}
+ and isinstance(first, str)
+ and first in {"bce", "binary_crossentropy"}
+ ):
+ effective_loss = [
+ default_loss_by_mode[self.training_mode],
+ *effective_loss[1:],
+ ]
+ return super().compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ scheduler=scheduler,
+ scheduler_params=scheduler_params,
+ loss=effective_loss,
+ loss_params=loss_params,
+ loss_weights=loss_weights,
+ callbacks=callbacks,
  )

+ def inbatch_logits(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+ if self.similarity_metric == "dot":
+ logits = torch.matmul(user_emb, item_emb.t())
+ elif self.similarity_metric == "cosine":
+ user_norm = F.normalize(user_emb, p=2, dim=-1)
+ item_norm = F.normalize(item_emb, p=2, dim=-1)
+ logits = torch.matmul(user_norm, item_norm.t())
+ elif self.similarity_metric == "euclidean":
+ user_sq = (user_emb**2).sum(dim=1, keepdim=True) # [B, 1]
+ item_sq = (item_emb**2).sum(dim=1, keepdim=True).t() # [1, B]
+ logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
+ else:
+ raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
+ return logits / self.temperature
+
  def compute_similarity(
  self, user_emb: torch.Tensor, item_emb: torch.Tensor
  ) -> torch.Tensor:
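A toy check of the dot-product branch of the new inbatch_logits() helper, written in plain torch so it does not depend on nextrec; for a batch of B users and their B matched items, the diagonal of the [B, B] matrix holds the positive scores:

    import torch

    B, D, temperature = 4, 8, 0.1
    user_emb = torch.randn(B, D)
    item_emb = torch.randn(B, D)
    logits = torch.matmul(user_emb, item_emb.t()) / temperature  # [B, B]
    pos_scores = logits.diag()                                   # matched user-item pairs
    print(logits.shape, pos_scores.shape)                        # torch.Size([4, 4]) torch.Size([4])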
@@ -2125,9 +2199,7 @@ class BaseMatchModel(BaseModel):

  def compute_loss(self, y_pred, y_true):
  if self.training_mode == "pointwise":
- if y_true is None:
- return torch.tensor(0.0, device=self.device)
- return self.loss_fn[0](y_pred, y_true)
+ return super().compute_loss(y_pred, y_true)

  # pairwise / listwise using inbatch neg
  elif self.training_mode in ["pairwise", "listwise"]:
@@ -2136,14 +2208,37 @@ class BaseMatchModel(BaseModel):
  "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
  )
  user_emb, item_emb = y_pred # [B, D], [B, D]
- logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
- logits = logits / self.temperature
- batch_size = logits.size(0)
- targets = torch.arange(
- batch_size, device=logits.device
- ) # [0, 1, 2, ..., B-1]
- # Cross-Entropy = InfoNCE
- loss = F.cross_entropy(logits, targets)
+ batch_size = user_emb.size(0)
+ if batch_size < 2:
+ return torch.tensor(0.0, device=user_emb.device)
+
+ logits = self.inbatch_logits(user_emb, item_emb) # [B, B]
+
+ eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
+ pos_logits = logits.diag() # [B]
+ neg_logits = logits.masked_select(~eye).view(batch_size, batch_size - 1) # [B, B-1]
+
+ loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+ if isinstance(loss_fn, SampledSoftmaxLoss):
+ loss = loss_fn(pos_logits, neg_logits)
+ elif isinstance(loss_fn, (BPRLoss, HingeLoss)):
+ loss = loss_fn(pos_logits, neg_logits)
+ elif isinstance(loss_fn, TripletLoss):
+ neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+ batch_size, batch_size - 1, item_emb.size(-1)
+ )
+ loss = loss_fn(user_emb, item_emb, neg_emb)
+ elif isinstance(loss_fn, InfoNCELoss) and self.similarity_metric == "dot":
+ neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+ batch_size, batch_size - 1, item_emb.size(-1)
+ )
+ loss = loss_fn(user_emb, item_emb, neg_emb)
+ else:
+ targets = torch.arange(batch_size, device=logits.device)
+ loss = F.cross_entropy(logits, targets)
+
+ if self.loss_weights is not None:
+ loss *= float(self.loss_weights[0])
  return loss
  else:
  raise ValueError(f"Unknown training mode: {self.training_mode}")
@@ -2154,17 +2249,23 @@ class BaseMatchModel(BaseModel):
  """Prepare data loader for specific features."""
  if isinstance(data, DataLoader):
  return data
-
- feature_data = {}
- for feature in features:
- if isinstance(data, dict):
- if feature.name in data:
- feature_data[feature.name] = data[feature.name]
- elif isinstance(data, pd.DataFrame):
- if feature.name in data.columns:
- feature_data[feature.name] = data[feature.name].values
- return self.prepare_data_loader(
- feature_data, batch_size=batch_size, shuffle=False
+ tensors = build_tensors_from_data(
+ data=data,
+ raw_data=data,
+ features=features,
+ target_columns=[],
+ id_columns=[],
+ )
+ if tensors is None:
+ raise ValueError(
+ "[BaseMatchModel-prepare_feature_data Error] No data available to create DataLoader."
+ )
+ dataset = TensorDictDataset(tensors)
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ collate_fn=collate_fn,
  )

  def encode_user(