nextrec 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
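The most visible API change in this release is a callback-driven training loop: `BaseModel.__init__()` and `compile()` now accept a `callbacks` argument (see the new `CallbackList`, `EarlyStopper`, `CheckpointSaver`, and `LearningRateScheduler` imports in the diff below), and model checkpoints switch from a `.model` to a `.pt` suffix. A minimal usage sketch, assuming a `DeepFM` model class from this package; the constructor keyword arguments are taken from the default-callback setup visible in the diff, while the optimizer string and the `fit()` call are illustrative only:

    from nextrec.basic.callback import EarlyStopper, CheckpointSaver

    model = DeepFM(...)  # hypothetical example; any BaseModel subclass applies
    model.compile(
        optimizer="adam",  # assumed optimizer name; compile() also accepts instances
        loss="bce",
        callbacks=[
            # monitor/patience/mode/save_best_only mirror the defaults built in fit()
            EarlyStopper(monitor="val_auc", patience=10, mode="max"),
            CheckpointSaver(save_path="session_tutorial/DeepFM_best.pt",
                            monitor="val_auc", mode="max", save_best_only=True),
        ],
    )
    model.fit(train_data)  # fit() now drives the on_train/on_epoch callback hooks shown below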
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 05/12/2025
+ Checkpoint: edit on 18/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -25,7 +25,13 @@ from torch.utils.data import DataLoader
  from torch.utils.data.distributed import DistributedSampler
  from torch.nn.parallel import DistributedDataParallel as DDP

- from nextrec.basic.callback import EarlyStopper
+ from nextrec.basic.callback import (
+ EarlyStopper,
+ CallbackList,
+ Callback,
+ CheckpointSaver,
+ LearningRateScheduler,
+ )
  from nextrec.basic.features import (
  DenseFeature,
  SparseFeature,
@@ -42,7 +48,15 @@ from nextrec.data.dataloader import build_tensors_from_data
  from nextrec.data.batch_utils import collate_fn, batch_to_dict
  from nextrec.data.data_processing import get_column_data, get_user_ids

- from nextrec.loss import get_loss_fn, get_loss_kwargs
+ from nextrec.loss import (
+ BPRLoss,
+ HingeLoss,
+ InfoNCELoss,
+ SampledSoftmaxLoss,
+ TripletLoss,
+ get_loss_fn,
+ get_loss_kwargs,
+ )
  from nextrec.utils.tensor import to_tensor
  from nextrec.utils.device import configure_device
  from nextrec.utils.optimizer import get_optimizer, get_scheduler
@@ -71,13 +85,14 @@ class BaseModel(FeatureSet, nn.Module):
  target: list[str] | str | None = None,
  id_columns: list[str] | str | None = None,
  task: str | list[str] | None = None,
- device: str = "cpu",
- early_stop_patience: int = 20,
- session_id: str | None = None,
  embedding_l1_reg: float = 0.0,
  dense_l1_reg: float = 0.0,
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
+ device: str = "cpu",
+ early_stop_patience: int = 20,
+ session_id: str | None = None,
+ callbacks: list[Callback] | None = None,
  distributed: bool = False,
  rank: int | None = None,
  world_size: int | None = None,
@@ -91,16 +106,20 @@ class BaseModel(FeatureSet, nn.Module):
  dense_features: DenseFeature definitions.
  sparse_features: SparseFeature definitions.
  sequence_features: SequenceFeature definitions.
- target: Target column name.
- id_columns: Identifier column name, only need to specify if GAUC is required.
+ target: Target column name. e.g., 'label' or ['label1', 'label2'].
+ id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
  task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
- device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
+
  embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
  dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
  embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
  dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
+
+ device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
  early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
- session_id: Session id for logging. If None, a default id with timestamps will be created.
+ session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
+ callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
+
  distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
  rank: Global rank (defaults to env RANK).
  world_size: Number of processes (defaults to env WORLD_SIZE).
@@ -126,11 +145,9 @@ class BaseModel(FeatureSet, nn.Module):
  self.session = create_session(session_id)
  self.session_path = self.session.root # pwd/session_id, path for this session
  self.checkpoint_path = os.path.join(
- self.session_path, self.model_name + "_checkpoint.model"
- ) # example: pwd/session_id/DeepFM_checkpoint.model
- self.best_path = os.path.join(
- self.session_path, self.model_name + "_best.model"
- )
+ self.session_path, self.model_name + "_checkpoint.pt"
+ ) # example: pwd/session_id/DeepFM_checkpoint.pt
+ self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
  self.features_config_path = os.path.join(
  self.session_path, "features_config.pkl"
  )
@@ -153,6 +170,7 @@ class BaseModel(FeatureSet, nn.Module):
  self.max_gradient_norm = 1.0
  self.logger_initialized = False
  self.training_logger = None
+ self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()

  def register_regularization_weights(
  self,
@@ -164,8 +182,22 @@ class BaseModel(FeatureSet, nn.Module):
  include_modules = include_modules or []
  embedding_layer = getattr(self, embedding_attr, None)
  embed_dict = getattr(embedding_layer, "embed_dict", None)
+ embedding_params: list[torch.Tensor] = []
  if embed_dict is not None:
- self.embedding_params.extend(embed.weight for embed in embed_dict.values())
+ embedding_params.extend(
+ embed.weight for embed in embed_dict.values() if hasattr(embed, "weight")
+ )
+ else:
+ weight = getattr(embedding_layer, "weight", None)
+ if isinstance(weight, torch.Tensor):
+ embedding_params.append(weight)
+
+ existing_embedding_ids = {id(param) for param in self.embedding_params}
+ for param in embedding_params:
+ if id(param) not in existing_embedding_ids:
+ self.embedding_params.append(param)
+ existing_embedding_ids.add(id(param))
+
  skip_types = (
  nn.BatchNorm1d,
  nn.BatchNorm2d,
@@ -174,6 +206,7 @@ class BaseModel(FeatureSet, nn.Module):
  nn.Dropout2d,
  nn.Dropout3d,
  )
+ existing_reg_ids = {id(param) for param in self.regularization_weights}
  for name, module in self.named_modules():
  if (
  module is self
@@ -184,7 +217,9 @@ class BaseModel(FeatureSet, nn.Module):
  ):
  continue
  if isinstance(module, nn.Linear):
- self.regularization_weights.append(module.weight)
+ if id(module.weight) not in existing_reg_ids:
+ self.regularization_weights.append(module.weight)
+ existing_reg_ids.add(id(module.weight))

  def add_reg_loss(self) -> torch.Tensor:
  reg_loss = torch.tensor(0.0, device=self.device)
@@ -337,6 +372,7 @@ class BaseModel(FeatureSet, nn.Module):
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
  loss_weights: int | float | list[int | float] | None = None,
+ callbacks: list[Callback] | None = None,
  ):
  """
  Configure the model for training.
@@ -348,6 +384,7 @@ class BaseModel(FeatureSet, nn.Module):
  loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
  loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
  loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+ callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
  """
  if loss_params is None:
  self.loss_params = {}
@@ -429,6 +466,11 @@ class BaseModel(FeatureSet, nn.Module):
  )
  self.loss_weights = weights

+ # Add callbacks from compile if provided
+ if callbacks:
+ for callback in callbacks:
+ self.callbacks.append(callback)
+
  def compute_loss(self, y_pred, y_true):
  if y_true is None:
  raise ValueError(
@@ -582,6 +624,53 @@ class BaseModel(FeatureSet, nn.Module):
  task=self.task, metrics=metrics, target_names=self.target_columns
  )
  ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+
+ # Setup default callbacks if none exist
+ if len(self.callbacks.callbacks) == 0:
+ if self.nums_task == 1:
+ monitor_metric = f"val_{self.metrics[0]}"
+ else:
+ monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
+
+ if self.early_stop_patience > 0:
+ self.callbacks.append(
+ EarlyStopper(
+ monitor=monitor_metric,
+ patience=self.early_stop_patience,
+ mode=self.best_metrics_mode,
+ restore_best_weights=not self.distributed,
+ verbose=1 if self.is_main_process else 0,
+ )
+ )
+
+ if self.is_main_process:
+ self.callbacks.append(
+ CheckpointSaver(
+ save_path=self.best_path,
+ monitor=monitor_metric,
+ mode=self.best_metrics_mode,
+ save_best_only=True,
+ verbose=1,
+ )
+ )
+
+ if self.scheduler_fn is not None:
+ self.callbacks.append(
+ LearningRateScheduler(
+ scheduler=self.scheduler_fn,
+ verbose=1 if self.is_main_process else 0,
+ )
+ )
+
+ self.callbacks.set_model(self)
+ self.callbacks.set_params(
+ {
+ "epochs": epochs,
+ "batch_size": batch_size,
+ "metrics": self.metrics,
+ }
+ )
+
  self.early_stopper = EarlyStopper(
  patience=self.early_stop_patience, mode=self.best_metrics_mode
  )
@@ -650,7 +739,9 @@ class BaseModel(FeatureSet, nn.Module):
  else:
  train_loader = train_data
  else:
- loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
+ result = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)
+ assert isinstance(result, tuple), "Expected tuple from prepare_data_loader with return_dataset=True"
+ loader, dataset = result
  if (
  auto_distributed_sampler
  and self.distributed
@@ -731,8 +822,13 @@ class BaseModel(FeatureSet, nn.Module):
  logging.info("")
  logging.info(colorize(f"Model device: {self.device}", bold=True))

+ self.callbacks.on_train_begin()
+
  for epoch in range(epochs):
  self.epoch_index = epoch
+
+ self.callbacks.on_epoch_begin(epoch)
+
  if is_streaming and self.is_main_process:
  logging.info("")
  logging.info(
@@ -742,10 +838,14 @@ class BaseModel(FeatureSet, nn.Module):
  # handle train result
  if (
  self.distributed
+ and isinstance(train_loader, DataLoader)
  and hasattr(train_loader, "sampler")
  and isinstance(train_loader.sampler, DistributedSampler)
  ):
  train_loader.sampler.set_epoch(epoch)
+ # Type guard: ensure train_loader is DataLoader for train_epoch
+ if not isinstance(train_loader, DataLoader):
+ raise TypeError(f"Expected DataLoader for training, got {type(train_loader)}")
  train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
  if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
  train_loss, train_metrics = train_result
@@ -805,6 +905,9 @@ class BaseModel(FeatureSet, nn.Module):
  train_log_payload, step=epoch + 1, split="train"
  )
  if valid_loader is not None:
+ # Call on_validation_begin
+ self.callbacks.on_validation_begin()
+
  # pass user_ids only if needed for GAUC metric
  val_metrics = self.evaluate(
  valid_loader,
@@ -851,17 +954,17 @@ class BaseModel(FeatureSet, nn.Module):
  color="cyan",
  )
  )
+
+ # Call on_validation_end
+ self.callbacks.on_validation_end()
  if val_metrics and self.training_logger:
  self.training_logger.log_metrics(
  val_metrics, step=epoch + 1, split="valid"
  )
+
  # Handle empty validation metrics
  if not val_metrics:
  if self.is_main_process:
- self.save_model(
- self.checkpoint_path, add_timestamp=False, verbose=False
- )
- self.best_checkpoint_path = self.checkpoint_path
  logging.info(
  colorize(
  "Warning: No validation metrics computed. Skipping validation for this epoch.",
@@ -869,81 +972,26 @@ class BaseModel(FeatureSet, nn.Module):
  )
  )
  continue
- if self.nums_task == 1:
- primary_metric_key = self.metrics[0]
- else:
- primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
- primary_metric = val_metrics.get(
- primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
- ) # get primary metric value, default to first metric if not found
-
- # In distributed mode, broadcast primary_metric to ensure all processes use the same value
- if self.distributed and dist.is_available() and dist.is_initialized():
- metric_tensor = torch.tensor(
- [primary_metric], device=self.device, dtype=torch.float32
- )
- dist.broadcast(metric_tensor, src=0)
- primary_metric = float(metric_tensor.item())
-
- improved = False
- # early stopping check
- if self.best_metrics_mode == "max":
- if primary_metric > self.best_metric:
- self.best_metric = primary_metric
- improved = True
- else:
- if primary_metric < self.best_metric:
- self.best_metric = primary_metric
- improved = True

- # save checkpoint and best model for main process
- if self.is_main_process:
- self.save_model(
- self.checkpoint_path, add_timestamp=False, verbose=False
- )
- logging.info(" ")
- if improved:
- logging.info(
- colorize(
- f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
- )
- )
- self.save_model(
- self.best_path, add_timestamp=False, verbose=False
- )
- self.best_checkpoint_path = self.best_path
- self.early_stopper.trial_counter = 0
- else:
- self.early_stopper.trial_counter += 1
- logging.info(
- colorize(
- f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
- )
- )
- if self.early_stopper.trial_counter >= self.early_stopper.patience:
- self.stop_training = True
- logging.info(
- colorize(
- f"Early stopping triggered after {epoch + 1} epochs",
- color="bright_red",
- bold=True,
- )
- )
- else:
- # Non-main processes also update trial_counter to keep in sync
- if improved:
- self.early_stopper.trial_counter = 0
- else:
- self.early_stopper.trial_counter += 1
+ # Prepare epoch logs for callbacks
+ epoch_logs = {**train_log_payload}
+ if val_metrics:
+ # Add val_ prefix to validation metrics
+ for k, v in val_metrics.items():
+ epoch_logs[f"val_{k}"] = v
  else:
+ # No validation data
+ epoch_logs = {**train_log_payload}
  if self.is_main_process:
  self.save_model(
  self.checkpoint_path, add_timestamp=False, verbose=False
  )
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
- self.best_checkpoint_path = self.best_path
+ self.best_checkpoint_path = self.checkpoint_path

- # Broadcast stop_training flag to all processes (always, regardless of validation)
+ # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
+ self.callbacks.on_epoch_end(epoch, epoch_logs)
+
+ # Broadcast stop_training flag to all processes
  if self.distributed and dist.is_available() and dist.is_initialized():
  stop_tensor = torch.tensor(
  [int(self.stop_training)], device=self.device
@@ -953,14 +1001,9 @@ class BaseModel(FeatureSet, nn.Module):

  if self.stop_training:
  break
- if self.scheduler_fn is not None:
- if isinstance(
- self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
- ):
- if valid_loader is not None:
- self.scheduler_fn.step(primary_metric)
- else:
- self.scheduler_fn.step()
+ # Call on_train_end for all callbacks
+ self.callbacks.on_train_end()
+
  if self.distributed and dist.is_available() and dist.is_initialized():
  dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
  if self.is_main_process:
@@ -972,9 +1015,17 @@ class BaseModel(FeatureSet, nn.Module):
  logging.info(
  colorize(f"Load best model from: {self.best_checkpoint_path}")
  )
- self.load_model(
- self.best_checkpoint_path, map_location=self.device, verbose=False
- )
+ if os.path.exists(self.best_checkpoint_path):
+ self.load_model(
+ self.best_checkpoint_path, map_location=self.device, verbose=False
+ )
+ elif self.is_main_process:
+ logging.info(
+ colorize(
+ f"Warning: Best checkpoint not found at {self.best_checkpoint_path}, skip loading best model.",
+ color="yellow",
+ )
+ )
  if self.training_logger:
  self.training_logger.close()
  return self
@@ -1563,7 +1614,7 @@ class BaseModel(FeatureSet, nn.Module):
  path=save_path,
  default_dir=self.session_path,
  default_name=self.model_name,
- suffix=".model",
+ suffix=".pt",
  add_timestamp=add_timestamp,
  )
  model_path = Path(target_path)
@@ -1603,16 +1654,16 @@ class BaseModel(FeatureSet, nn.Module):
  self.to(self.device)
  base_path = Path(save_path)
  if base_path.is_dir():
- model_files = sorted(base_path.glob("*.model"))
+ model_files = sorted(base_path.glob("*.pt"))
  if not model_files:
  raise FileNotFoundError(
- f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}"
+ f"[BaseModel-load-model Error] No *.pt file found in directory: {base_path}"
  )
  model_path = model_files[-1]
  config_dir = base_path
  else:
  model_path = (
- base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+ base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
  )
  config_dir = model_path.parent
  if not model_path.exists():
@@ -1665,21 +1716,21 @@ class BaseModel(FeatureSet, nn.Module):
  ) -> "BaseModel":
  """
  Load a model from a checkpoint path. The checkpoint path should contain:
- a .model file and a features_config.pkl file.
+ a .pt file and a features_config.pkl file.
  """
  base_path = Path(checkpoint_path)
  verbose = kwargs.pop("verbose", True)
  if base_path.is_dir():
- model_candidates = sorted(base_path.glob("*.model"))
+ model_candidates = sorted(base_path.glob("*.pt"))
  if not model_candidates:
  raise FileNotFoundError(
- f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}"
+ f"[BaseModel-from-checkpoint Error] No *.pt file found under: {base_path}"
  )
  model_file = model_candidates[-1]
  config_dir = base_path
  else:
  model_file = (
- base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+ base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
  )
  config_dir = model_file.parent
  features_config_path = config_dir / "features_config.pkl"
@@ -1849,7 +1900,9 @@ class BaseModel(FeatureSet, nn.Module):
  class BaseMatchModel(BaseModel):
  """
  Base class for match (retrieval/recall) models
- Supports pointwise, pairwise, and listwise training modes
+
+ - Pointwise: predicts a user-item match score/probability using labels (default target: 'label')
+ - Pairwise/Listwise: trains with in-batch negatives; labels can be omitted by setting target=None
  """

  @property
@@ -1889,6 +1942,16 @@ class BaseMatchModel(BaseModel):
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
  early_stop_patience: int = 20,
+ target: list[str] | str | None = "label",
+ id_columns: list[str] | str | None = None,
+ task: str | list[str] | None = None,
+ session_id: str | None = None,
+ callbacks: list[Callback] | None = None,
+ distributed: bool = False,
+ rank: int | None = None,
+ world_size: int | None = None,
+ local_rank: int | None = None,
+ ddp_find_unused_parameters: bool = False,
  **kwargs,
  ):

@@ -1913,14 +1976,22 @@ class BaseMatchModel(BaseModel):
  dense_features=all_dense_features,
  sparse_features=all_sparse_features,
  sequence_features=all_sequence_features,
- target=["label"],
- task="binary",
+ target=target,
+ id_columns=id_columns,
+ task=task,
  device=device,
  embedding_l1_reg=embedding_l1_reg,
  dense_l1_reg=dense_l1_reg,
  embedding_l2_reg=embedding_l2_reg,
  dense_l2_reg=dense_l2_reg,
  early_stop_patience=early_stop_patience,
+ session_id=session_id,
+ callbacks=callbacks,
+ distributed=distributed,
+ rank=rank,
+ world_size=world_size,
+ local_rank=local_rank,
+ ddp_find_unused_parameters=ddp_find_unused_parameters,
  **kwargs,
  )

@@ -1991,73 +2062,74 @@ class BaseMatchModel(BaseModel):
  scheduler_params: dict | None = None,
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
  loss_params: dict | list[dict] | None = None,
+ loss_weights: int | float | list[int | float] | None = None,
+ callbacks: list[Callback] | None = None,
  ):
  """
- Compile match model with optimizer, scheduler, and loss function.
- Mirrors BaseModel.compile while adding training_mode validation for match tasks.
+ Configure the match model for training.
+
+ This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
  """
  if self.training_mode not in self.support_training_modes:
  raise ValueError(
  f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
  )
- # Call parent compile with match-specific logic
- optimizer_params = optimizer_params or {}

- self.optimizer_name = (
- optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
- )
- self.optimizer_params = optimizer_params
- if isinstance(scheduler, str):
- self.scheduler_name = scheduler
- elif scheduler is not None:
- # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
- self.scheduler_name = getattr(
- scheduler,
- "__name__",
- getattr(scheduler.__class__, "__name__", str(scheduler)),
- )
- else:
- self.scheduler_name = None
- self.scheduler_params = scheduler_params or {}
- self.loss_config = loss
- self.loss_params = loss_params or {}
-
- self.optimizer_fn = get_optimizer(
- optimizer=optimizer, params=self.parameters(), **optimizer_params
- )
- # Set loss function based on training mode
- default_losses = {
+ default_loss_by_mode: dict[str, str] = {
  "pointwise": "bce",
  "pairwise": "bpr",
  "listwise": "sampled_softmax",
  }

- if loss is None:
- loss_value = default_losses.get(self.training_mode, "bce")
- elif isinstance(loss, list):
- loss_value = (
- loss[0]
- if loss and loss[0] is not None
- else default_losses.get(self.training_mode, "bce")
- )
- else:
- loss_value = loss
-
- # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
- if self.training_mode in {"pairwise", "listwise"} and loss_value in {
- "bce",
- "binary_crossentropy",
- }:
- loss_value = default_losses.get(self.training_mode, loss_value)
- loss_kwargs = get_loss_kwargs(self.loss_params, 0)
- self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
- # set scheduler
- self.scheduler_fn = (
- get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {}))
- if scheduler
- else None
+ effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+ if effective_loss is None:
+ effective_loss = default_loss_by_mode[self.training_mode]
+ elif isinstance(effective_loss, (str,)):
+ if (
+ self.training_mode in {"pairwise", "listwise"}
+ and effective_loss in {"bce", "binary_crossentropy"}
+ ):
+ effective_loss = default_loss_by_mode[self.training_mode]
+ elif isinstance(effective_loss, list):
+ if not effective_loss:
+ effective_loss = [default_loss_by_mode[self.training_mode]]
+ else:
+ first = effective_loss[0]
+ if (
+ self.training_mode in {"pairwise", "listwise"}
+ and isinstance(first, str)
+ and first in {"bce", "binary_crossentropy"}
+ ):
+ effective_loss = [
+ default_loss_by_mode[self.training_mode],
+ *effective_loss[1:],
+ ]
+ return super().compile(
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ scheduler=scheduler,
+ scheduler_params=scheduler_params,
+ loss=effective_loss,
+ loss_params=loss_params,
+ loss_weights=loss_weights,
+ callbacks=callbacks,
  )

+ def inbatch_logits(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+ if self.similarity_metric == "dot":
+ logits = torch.matmul(user_emb, item_emb.t())
+ elif self.similarity_metric == "cosine":
+ user_norm = F.normalize(user_emb, p=2, dim=-1)
+ item_norm = F.normalize(item_emb, p=2, dim=-1)
+ logits = torch.matmul(user_norm, item_norm.t())
+ elif self.similarity_metric == "euclidean":
+ user_sq = (user_emb**2).sum(dim=1, keepdim=True) # [B, 1]
+ item_sq = (item_emb**2).sum(dim=1, keepdim=True).t() # [1, B]
+ logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
+ else:
+ raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
+ return logits / self.temperature
+
  def compute_similarity(
  self, user_emb: torch.Tensor, item_emb: torch.Tensor
  ) -> torch.Tensor:
@@ -2127,9 +2199,7 @@ class BaseMatchModel(BaseModel):

  def compute_loss(self, y_pred, y_true):
  if self.training_mode == "pointwise":
- if y_true is None:
- return torch.tensor(0.0, device=self.device)
- return self.loss_fn[0](y_pred, y_true)
+ return super().compute_loss(y_pred, y_true)

  # pairwise / listwise using inbatch neg
  elif self.training_mode in ["pairwise", "listwise"]:
@@ -2138,14 +2208,37 @@ class BaseMatchModel(BaseModel):
  "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
  )
  user_emb, item_emb = y_pred # [B, D], [B, D]
- logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
- logits = logits / self.temperature
- batch_size = logits.size(0)
- targets = torch.arange(
- batch_size, device=logits.device
- ) # [0, 1, 2, ..., B-1]
- # Cross-Entropy = InfoNCE
- loss = F.cross_entropy(logits, targets)
+ batch_size = user_emb.size(0)
+ if batch_size < 2:
+ return torch.tensor(0.0, device=user_emb.device)
+
+ logits = self.inbatch_logits(user_emb, item_emb) # [B, B]
+
+ eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
+ pos_logits = logits.diag() # [B]
+ neg_logits = logits.masked_select(~eye).view(batch_size, batch_size - 1) # [B, B-1]
+
+ loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+ if isinstance(loss_fn, SampledSoftmaxLoss):
+ loss = loss_fn(pos_logits, neg_logits)
+ elif isinstance(loss_fn, (BPRLoss, HingeLoss)):
+ loss = loss_fn(pos_logits, neg_logits)
+ elif isinstance(loss_fn, TripletLoss):
+ neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+ batch_size, batch_size - 1, item_emb.size(-1)
+ )
+ loss = loss_fn(user_emb, item_emb, neg_emb)
+ elif isinstance(loss_fn, InfoNCELoss) and self.similarity_metric == "dot":
+ neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+ batch_size, batch_size - 1, item_emb.size(-1)
+ )
+ loss = loss_fn(user_emb, item_emb, neg_emb)
+ else:
+ targets = torch.arange(batch_size, device=logits.device)
+ loss = F.cross_entropy(logits, targets)
+
+ if self.loss_weights is not None:
+ loss *= float(self.loss_weights[0])
  return loss
  else:
  raise ValueError(f"Unknown training mode: {self.training_mode}")
@@ -2156,17 +2249,23 @@ class BaseMatchModel(BaseModel):
  """Prepare data loader for specific features."""
  if isinstance(data, DataLoader):
  return data
-
- feature_data = {}
- for feature in features:
- if isinstance(data, dict):
- if feature.name in data:
- feature_data[feature.name] = data[feature.name]
- elif isinstance(data, pd.DataFrame):
- if feature.name in data.columns:
- feature_data[feature.name] = data[feature.name].values
- return self.prepare_data_loader(
- feature_data, batch_size=batch_size, shuffle=False
+ tensors = build_tensors_from_data(
+ data=data,
+ raw_data=data,
+ features=features,
+ target_columns=[],
+ id_columns=[],
+ )
+ if tensors is None:
+ raise ValueError(
+ "[BaseMatchModel-prepare_feature_data Error] No data available to create DataLoader."
+ )
+ dataset = TensorDictDataset(tensors)
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ collate_fn=collate_fn,
  )

  def encode_user(