nextrec 0.4.6-py3-none-any.whl → 0.4.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/model.py +274 -173
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- {nextrec-0.4.6.dist-info → nextrec-0.4.7.dist-info}/METADATA +30 -25
- {nextrec-0.4.6.dist-info → nextrec-0.4.7.dist-info}/RECORD +14 -14
- {nextrec-0.4.6.dist-info → nextrec-0.4.7.dist-info}/WHEEL +0 -0
- {nextrec-0.4.6.dist-info → nextrec-0.4.7.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.6.dist-info → nextrec-0.4.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 18/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -25,7 +25,13 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from nextrec.basic.callback import
+from nextrec.basic.callback import (
+    EarlyStopper,
+    CallbackList,
+    Callback,
+    CheckpointSaver,
+    LearningRateScheduler,
+)
 from nextrec.basic.features import (
     DenseFeature,
     SparseFeature,
@@ -42,7 +48,15 @@ from nextrec.data.dataloader import build_tensors_from_data
 from nextrec.data.batch_utils import collate_fn, batch_to_dict
 from nextrec.data.data_processing import get_column_data, get_user_ids
 
-from nextrec.loss import
+from nextrec.loss import (
+    BPRLoss,
+    HingeLoss,
+    InfoNCELoss,
+    SampledSoftmaxLoss,
+    TripletLoss,
+    get_loss_fn,
+    get_loss_kwargs,
+)
 from nextrec.utils.tensor import to_tensor
 from nextrec.utils.device import configure_device
 from nextrec.utils.optimizer import get_optimizer, get_scheduler
@@ -71,13 +85,14 @@ class BaseModel(FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
-        device: str = "cpu",
-        early_stop_patience: int = 20,
-        session_id: str | None = None,
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
+        device: str = "cpu",
+        early_stop_patience: int = 20,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -91,16 +106,20 @@ class BaseModel(FeatureSet, nn.Module):
             dense_features: DenseFeature definitions.
             sparse_features: SparseFeature definitions.
             sequence_features: SequenceFeature definitions.
-            target: Target column name.
-            id_columns: Identifier column name, only need to specify if GAUC is required.
+            target: Target column name. e.g., 'label' or ['label1', 'label2'].
+            id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
             task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-
+
             embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
             dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
             embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
             dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
+
+            device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
             early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-            session_id: Session id for logging. If None, a default id with timestamps will be created.
+            session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
+            callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
+
             distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
             rank: Global rank (defaults to env RANK).
             world_size: Number of processes (defaults to env WORLD_SIZE).
@@ -151,6 +170,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
+        self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
 
     def register_regularization_weights(
         self,
@@ -162,8 +182,22 @@ class BaseModel(FeatureSet, nn.Module):
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
         embed_dict = getattr(embedding_layer, "embed_dict", None)
+        embedding_params: list[torch.Tensor] = []
         if embed_dict is not None:
-
+            embedding_params.extend(
+                embed.weight for embed in embed_dict.values() if hasattr(embed, "weight")
+            )
+        else:
+            weight = getattr(embedding_layer, "weight", None)
+            if isinstance(weight, torch.Tensor):
+                embedding_params.append(weight)
+
+        existing_embedding_ids = {id(param) for param in self.embedding_params}
+        for param in embedding_params:
+            if id(param) not in existing_embedding_ids:
+                self.embedding_params.append(param)
+                existing_embedding_ids.add(id(param))
+
         skip_types = (
             nn.BatchNorm1d,
             nn.BatchNorm2d,
@@ -172,6 +206,7 @@ class BaseModel(FeatureSet, nn.Module):
             nn.Dropout2d,
             nn.Dropout3d,
         )
+        existing_reg_ids = {id(param) for param in self.regularization_weights}
         for name, module in self.named_modules():
             if (
                 module is self
@@ -182,7 +217,9 @@ class BaseModel(FeatureSet, nn.Module):
             ):
                 continue
             if isinstance(module, nn.Linear):
-
+                if id(module.weight) not in existing_reg_ids:
+                    self.regularization_weights.append(module.weight)
+                    existing_reg_ids.add(id(module.weight))
 
     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
@@ -335,6 +372,7 @@ class BaseModel(FeatureSet, nn.Module):
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
         Configure the model for training.
@@ -346,6 +384,7 @@ class BaseModel(FeatureSet, nn.Module):
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+            callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         if loss_params is None:
             self.loss_params = {}
@@ -427,6 +466,11 @@ class BaseModel(FeatureSet, nn.Module):
             )
         self.loss_weights = weights
 
+        # Add callbacks from compile if provided
+        if callbacks:
+            for callback in callbacks:
+                self.callbacks.append(callback)
+
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError(
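Note: with 0.4.7 there are two places to hand callbacks to a model: the constructor (stored in a CallbackList) and compile(), whose list is appended to it. A usage sketch follows; MyRecModel and the feature lists are placeholders for any BaseModel subclass and its feature definitions, and the keyword arguments shown for EarlyStopper/CheckpointSaver are the ones the default setup in the next hunk uses, so treat the concrete values as illustrative rather than library defaults.

from nextrec.basic.callback import CheckpointSaver, EarlyStopper

# 'MyRecModel', 'dense_features', 'sparse_features' are hypothetical placeholders.
model = MyRecModel(
    dense_features=dense_features,
    sparse_features=sparse_features,
    target="label",
    callbacks=[EarlyStopper(monitor="val_auc", patience=5, mode="max")],
)
model.compile(
    optimizer="adam",
    loss="bce",
    callbacks=[CheckpointSaver(save_path="best.pt", monitor="val_auc", mode="max", save_best_only=True)],
)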
@@ -580,6 +624,53 @@ class BaseModel(FeatureSet, nn.Module):
                 task=self.task, metrics=metrics, target_names=self.target_columns
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+
+        # Setup default callbacks if none exist
+        if len(self.callbacks.callbacks) == 0:
+            if self.nums_task == 1:
+                monitor_metric = f"val_{self.metrics[0]}"
+            else:
+                monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
+
+            if self.early_stop_patience > 0:
+                self.callbacks.append(
+                    EarlyStopper(
+                        monitor=monitor_metric,
+                        patience=self.early_stop_patience,
+                        mode=self.best_metrics_mode,
+                        restore_best_weights=not self.distributed,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+            if self.is_main_process:
+                self.callbacks.append(
+                    CheckpointSaver(
+                        save_path=self.best_path,
+                        monitor=monitor_metric,
+                        mode=self.best_metrics_mode,
+                        save_best_only=True,
+                        verbose=1,
+                    )
+                )
+
+            if self.scheduler_fn is not None:
+                self.callbacks.append(
+                    LearningRateScheduler(
+                        scheduler=self.scheduler_fn,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+        self.callbacks.set_model(self)
+        self.callbacks.set_params(
+            {
+                "epochs": epochs,
+                "batch_size": batch_size,
+                "metrics": self.metrics,
+            }
+        )
+
         self.early_stopper = EarlyStopper(
             patience=self.early_stop_patience, mode=self.best_metrics_mode
         )
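Note: fit() now drives training through callback hooks (on_train_begin, on_epoch_begin, on_validation_begin/end, on_epoch_end, on_train_end) and passes an epoch_logs dict whose validation metrics carry a val_ prefix. A minimal custom-callback sketch, under the assumptions that the Callback base class exposes no-op versions of these hooks and takes no required constructor arguments (callback.py itself is not shown in this section):

from nextrec.basic.callback import Callback

class MetricPrinter(Callback):
    """Illustrative callback: print one monitored metric at the end of each epoch."""

    def __init__(self, monitor: str = "val_auc"):
        super().__init__()
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is not None:
            print(f"epoch {epoch + 1}: {self.monitor} = {value:.4f}")

An instance of such a class could then be passed through the callbacks argument of either the constructor or compile().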
@@ -648,7 +739,9 @@ class BaseModel(FeatureSet, nn.Module):
             else:
                 train_loader = train_data
         else:
-
+            result = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)
+            assert isinstance(result, tuple), "Expected tuple from prepare_data_loader with return_dataset=True"
+            loader, dataset = result
             if (
                 auto_distributed_sampler
                 and self.distributed
@@ -729,8 +822,13 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info("")
             logging.info(colorize(f"Model device: {self.device}", bold=True))
 
+        self.callbacks.on_train_begin()
+
         for epoch in range(epochs):
             self.epoch_index = epoch
+
+            self.callbacks.on_epoch_begin(epoch)
+
             if is_streaming and self.is_main_process:
                 logging.info("")
                 logging.info(
@@ -740,10 +838,14 @@ class BaseModel(FeatureSet, nn.Module):
             # handle train result
             if (
                 self.distributed
+                and isinstance(train_loader, DataLoader)
                 and hasattr(train_loader, "sampler")
                 and isinstance(train_loader.sampler, DistributedSampler)
             ):
                 train_loader.sampler.set_epoch(epoch)
+            # Type guard: ensure train_loader is DataLoader for train_epoch
+            if not isinstance(train_loader, DataLoader):
+                raise TypeError(f"Expected DataLoader for training, got {type(train_loader)}")
             train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
             if isinstance(train_result, tuple):  # [avg_loss, metrics_dict]
                 train_loss, train_metrics = train_result
@@ -803,6 +905,9 @@ class BaseModel(FeatureSet, nn.Module):
                     train_log_payload, step=epoch + 1, split="train"
                 )
             if valid_loader is not None:
+                # Call on_validation_begin
+                self.callbacks.on_validation_begin()
+
                 # pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(
                     valid_loader,
@@ -849,17 +954,17 @@ class BaseModel(FeatureSet, nn.Module):
                             color="cyan",
                         )
                     )
+
+                # Call on_validation_end
+                self.callbacks.on_validation_end()
                 if val_metrics and self.training_logger:
                     self.training_logger.log_metrics(
                         val_metrics, step=epoch + 1, split="valid"
                     )
+
                 # Handle empty validation metrics
                 if not val_metrics:
                     if self.is_main_process:
-                        self.save_model(
-                            self.checkpoint_path, add_timestamp=False, verbose=False
-                        )
-                        self.best_checkpoint_path = self.checkpoint_path
                         logging.info(
                             colorize(
                                 "Warning: No validation metrics computed. Skipping validation for this epoch.",
@@ -867,81 +972,26 @@ class BaseModel(FeatureSet, nn.Module):
                             )
                         )
                     continue
-                if self.nums_task == 1:
-                    primary_metric_key = self.metrics[0]
-                else:
-                    primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
-                primary_metric = val_metrics.get(
-                    primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
-                )  # get primary metric value, default to first metric if not found
-
-                # In distributed mode, broadcast primary_metric to ensure all processes use the same value
-                if self.distributed and dist.is_available() and dist.is_initialized():
-                    metric_tensor = torch.tensor(
-                        [primary_metric], device=self.device, dtype=torch.float32
-                    )
-                    dist.broadcast(metric_tensor, src=0)
-                    primary_metric = float(metric_tensor.item())
-
-                improved = False
-                # early stopping check
-                if self.best_metrics_mode == "max":
-                    if primary_metric > self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
-                else:
-                    if primary_metric < self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
 
-                #
-
-
-
-                )
-
-                if improved:
-                    logging.info(
-                        colorize(
-                            f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
-                        )
-                    )
-                    self.save_model(
-                        self.best_path, add_timestamp=False, verbose=False
-                    )
-                    self.best_checkpoint_path = self.best_path
-                    self.early_stopper.trial_counter = 0
-                else:
-                    self.early_stopper.trial_counter += 1
-                    logging.info(
-                        colorize(
-                            f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
-                        )
-                    )
-                    if self.early_stopper.trial_counter >= self.early_stopper.patience:
-                        self.stop_training = True
-                        logging.info(
-                            colorize(
-                                f"Early stopping triggered after {epoch + 1} epochs",
-                                color="bright_red",
-                                bold=True,
-                            )
-                        )
-                else:
-                    # Non-main processes also update trial_counter to keep in sync
-                    if improved:
-                        self.early_stopper.trial_counter = 0
-                    else:
-                        self.early_stopper.trial_counter += 1
+            # Prepare epoch logs for callbacks
+            epoch_logs = {**train_log_payload}
+            if val_metrics:
+                # Add val_ prefix to validation metrics
+                for k, v in val_metrics.items():
+                    epoch_logs[f"val_{k}"] = v
             else:
+                # No validation data
+                epoch_logs = {**train_log_payload}
             if self.is_main_process:
                 self.save_model(
                     self.checkpoint_path, add_timestamp=False, verbose=False
                 )
-                self.
-                self.best_checkpoint_path = self.best_path
+                self.best_checkpoint_path = self.checkpoint_path
 
-            #
+            # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
+            self.callbacks.on_epoch_end(epoch, epoch_logs)
+
+            # Broadcast stop_training flag to all processes
             if self.distributed and dist.is_available() and dist.is_initialized():
                 stop_tensor = torch.tensor(
                     [int(self.stop_training)], device=self.device
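Note: this hunk removes fit()'s inline best-metric tracking, patience counting, and best-checkpoint bookkeeping; on_epoch_end() on the registered callbacks now owns that logic. For orientation, the monitor/patience pattern such a callback typically implements looks roughly like the sketch below. This is an illustration under assumptions, not the actual nextrec EarlyStopper (callback.py is not shown in this section, and restore-best-weights handling is omitted):

class SketchEarlyStopper:
    """Toy monitor/patience logic; names and defaults are illustrative."""

    def __init__(self, monitor="val_auc", patience=20, mode="max", min_delta=0.0):
        self.monitor, self.patience, self.mode, self.min_delta = monitor, patience, mode, min_delta
        self.best = float("-inf") if mode == "max" else float("inf")
        self.wait = 0
        self.model = None  # assumed to be set by CallbackList.set_model()

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return
        if self.mode == "max":
            improved = current > self.best + self.min_delta
        else:
            improved = current < self.best - self.min_delta
        if improved:
            self.best, self.wait = current, 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True  # fit() broadcasts and honors this flag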
@@ -951,14 +1001,9 @@ class BaseModel(FeatureSet, nn.Module):
 
             if self.stop_training:
                 break
-
-
-
-            ):
-                if valid_loader is not None:
-                    self.scheduler_fn.step(primary_metric)
-                else:
-                    self.scheduler_fn.step()
+        # Call on_train_end for all callbacks
+        self.callbacks.on_train_end()
+
         if self.distributed and dist.is_available() and dist.is_initialized():
             dist.barrier()  # dist.barrier() will wait for all processes, like async all_reduce()
         if self.is_main_process:
@@ -970,9 +1015,17 @@ class BaseModel(FeatureSet, nn.Module):
                 logging.info(
                     colorize(f"Load best model from: {self.best_checkpoint_path}")
                 )
-            self.
-            self.
-
+            if os.path.exists(self.best_checkpoint_path):
+                self.load_model(
+                    self.best_checkpoint_path, map_location=self.device, verbose=False
+                )
+            elif self.is_main_process:
+                logging.info(
+                    colorize(
+                        f"Warning: Best checkpoint not found at {self.best_checkpoint_path}, skip loading best model.",
+                        color="yellow",
+                    )
+                )
         if self.training_logger:
             self.training_logger.close()
         return self
@@ -1847,7 +1900,9 @@ class BaseModel(FeatureSet, nn.Module):
 class BaseMatchModel(BaseModel):
     """
     Base class for match (retrieval/recall) models
-
+
+    - Pointwise: predicts a user-item match score/probability using labels (default target: 'label')
+    - Pairwise/Listwise: trains with in-batch negatives; labels can be omitted by setting target=None
     """
 
     @property
@@ -1887,6 +1942,16 @@ class BaseMatchModel(BaseModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
+        target: list[str] | str | None = "label",
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
+        distributed: bool = False,
+        rank: int | None = None,
+        world_size: int | None = None,
+        local_rank: int | None = None,
+        ddp_find_unused_parameters: bool = False,
         **kwargs,
     ):
 
@@ -1911,14 +1976,22 @@ class BaseMatchModel(BaseModel):
             dense_features=all_dense_features,
             sparse_features=all_sparse_features,
             sequence_features=all_sequence_features,
-            target=
-
+            target=target,
+            id_columns=id_columns,
+            task=task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
+            session_id=session_id,
+            callbacks=callbacks,
+            distributed=distributed,
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            ddp_find_unused_parameters=ddp_find_unused_parameters,
             **kwargs,
         )
 
@@ -1989,73 +2062,74 @@ class BaseMatchModel(BaseModel):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
-
-
+        Configure the match model for training.
+
+        This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
                 f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )
-        # Call parent compile with match-specific logic
-        optimizer_params = optimizer_params or {}
 
-
-            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
-        )
-        self.optimizer_params = optimizer_params
-        if isinstance(scheduler, str):
-            self.scheduler_name = scheduler
-        elif scheduler is not None:
-            # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.scheduler_name = getattr(
-                scheduler,
-                "__name__",
-                getattr(scheduler.__class__, "__name__", str(scheduler)),
-            )
-        else:
-            self.scheduler_name = None
-        self.scheduler_params = scheduler_params or {}
-        self.loss_config = loss
-        self.loss_params = loss_params or {}
-
-        self.optimizer_fn = get_optimizer(
-            optimizer=optimizer, params=self.parameters(), **optimizer_params
-        )
-        # Set loss function based on training mode
-        default_losses = {
+        default_loss_by_mode: dict[str, str] = {
             "pointwise": "bce",
             "pairwise": "bpr",
             "listwise": "sampled_softmax",
         }
 
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+        if effective_loss is None:
+            effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, (str,)):
+            if (
+                self.training_mode in {"pairwise", "listwise"}
+                and effective_loss in {"bce", "binary_crossentropy"}
+            ):
+                effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, list):
+            if not effective_loss:
+                effective_loss = [default_loss_by_mode[self.training_mode]]
+            else:
+                first = effective_loss[0]
+                if (
+                    self.training_mode in {"pairwise", "listwise"}
+                    and isinstance(first, str)
+                    and first in {"bce", "binary_crossentropy"}
+                ):
+                    effective_loss = [
+                        default_loss_by_mode[self.training_mode],
+                        *effective_loss[1:],
+                    ]
+        return super().compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=effective_loss,
+            loss_params=loss_params,
+            loss_weights=loss_weights,
+            callbacks=callbacks,
         )
 
+    def inbatch_logits(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+        if self.similarity_metric == "dot":
+            logits = torch.matmul(user_emb, item_emb.t())
+        elif self.similarity_metric == "cosine":
+            user_norm = F.normalize(user_emb, p=2, dim=-1)
+            item_norm = F.normalize(item_emb, p=2, dim=-1)
+            logits = torch.matmul(user_norm, item_norm.t())
+        elif self.similarity_metric == "euclidean":
+            user_sq = (user_emb**2).sum(dim=1, keepdim=True)  # [B, 1]
+            item_sq = (item_emb**2).sum(dim=1, keepdim=True).t()  # [1, B]
+            logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
+        else:
+            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
+        return logits / self.temperature
+
     def compute_similarity(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
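Note: the new inbatch_logits() scores every user embedding against every item embedding in the batch, so the diagonal of the resulting [B, B] matrix holds the positive pairs and the off-diagonal entries serve as negatives. A small self-contained sketch of the same idea in plain PyTorch, independent of nextrec; the metric names mirror the attribute values the method reads, while the temperature value is an arbitrary illustration:

import torch
import torch.nn.functional as F

def sketch_inbatch_logits(user_emb, item_emb, metric="dot", temperature=0.05):
    """Toy re-implementation of the idea behind BaseMatchModel.inbatch_logits."""
    if metric == "dot":
        logits = user_emb @ item_emb.t()
    elif metric == "cosine":
        logits = F.normalize(user_emb, dim=-1) @ F.normalize(item_emb, dim=-1).t()
    elif metric == "euclidean":
        # Negative squared distance, so "closer" still means "higher score".
        logits = -torch.cdist(user_emb, item_emb, p=2).pow(2)
    else:
        raise ValueError(metric)
    return logits / temperature

u, i = torch.randn(4, 8), torch.randn(4, 8)
print(sketch_inbatch_logits(u, i, "cosine").shape)  # torch.Size([4, 4])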
@@ -2125,9 +2199,7 @@ class BaseMatchModel(BaseModel):
 
     def compute_loss(self, y_pred, y_true):
         if self.training_mode == "pointwise":
-
-            return torch.tensor(0.0, device=self.device)
-            return self.loss_fn[0](y_pred, y_true)
+            return super().compute_loss(y_pred, y_true)
 
         # pairwise / listwise using inbatch neg
         elif self.training_mode in ["pairwise", "listwise"]:
@@ -2136,14 +2208,37 @@ class BaseMatchModel(BaseModel):
                     "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
                 )
             user_emb, item_emb = y_pred  # [B, D], [B, D]
-
-
-
-
-
-
-
-
+            batch_size = user_emb.size(0)
+            if batch_size < 2:
+                return torch.tensor(0.0, device=user_emb.device)
+
+            logits = self.inbatch_logits(user_emb, item_emb)  # [B, B]
+
+            eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
+            pos_logits = logits.diag()  # [B]
+            neg_logits = logits.masked_select(~eye).view(batch_size, batch_size - 1)  # [B, B-1]
+
+            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+            if isinstance(loss_fn, SampledSoftmaxLoss):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, (BPRLoss, HingeLoss)):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, TripletLoss):
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            elif isinstance(loss_fn, InfoNCELoss) and self.similarity_metric == "dot":
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            else:
+                targets = torch.arange(batch_size, device=logits.device)
+                loss = F.cross_entropy(logits, targets)
+
+            if self.loss_weights is not None:
+                loss *= float(self.loss_weights[0])
             return loss
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
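Note: for pairwise/listwise training the positives come from the diagonal of the in-batch logits and the negatives from the remaining entries of each row; unrecognized loss instances fall back to cross-entropy against the diagonal targets. A standalone sketch of that extraction and fallback (the BPR expression here is the standard pairwise definition, shown only for illustration; nextrec's BPRLoss may reduce differently):

import torch
import torch.nn.functional as F

batch = 4
logits = torch.randn(batch, batch)  # stand-in for inbatch_logits(user_emb, item_emb)

eye = torch.eye(batch, dtype=torch.bool)
pos_logits = logits.diag()                                      # [B] matched pairs
neg_logits = logits.masked_select(~eye).view(batch, batch - 1)  # [B, B-1] in-batch negatives

# A pairwise objective such as BPR compares each positive against its negatives.
bpr = -F.logsigmoid(pos_logits.unsqueeze(1) - neg_logits).mean()

# The listwise fallback in compute_loss is cross-entropy with the diagonal as target.
ce = F.cross_entropy(logits, torch.arange(batch))
print(bpr.item(), ce.item())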
@@ -2154,17 +2249,23 @@ class BaseMatchModel(BaseModel):
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
-
-
-
-
-
-
-
-
-
-
-
+        tensors = build_tensors_from_data(
+            data=data,
+            raw_data=data,
+            features=features,
+            target_columns=[],
+            id_columns=[],
+        )
+        if tensors is None:
+            raise ValueError(
+                "[BaseMatchModel-prepare_feature_data Error] No data available to create DataLoader."
+            )
+        dataset = TensorDictDataset(tensors)
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=collate_fn,
         )
 
     def encode_user(
|