nextrec 0.4.6-py3-none-any.whl → 0.4.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/model.py +289 -173
- nextrec/cli.py +27 -1
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- nextrec/utils/cli_utils.py +58 -0
- nextrec/utils/config.py +5 -4
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/METADATA +32 -26
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/RECORD +17 -16
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/WHEEL +0 -0
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.6.dist-info → nextrec-0.4.8.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 18/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -25,7 +25,13 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from nextrec.basic.callback import
+from nextrec.basic.callback import (
+    EarlyStopper,
+    CallbackList,
+    Callback,
+    CheckpointSaver,
+    LearningRateScheduler,
+)
 from nextrec.basic.features import (
     DenseFeature,
     SparseFeature,
@@ -42,7 +48,14 @@ from nextrec.data.dataloader import build_tensors_from_data
 from nextrec.data.batch_utils import collate_fn, batch_to_dict
 from nextrec.data.data_processing import get_column_data, get_user_ids
 
-from nextrec.loss import
+from nextrec.loss import (
+    BPRLoss,
+    HingeLoss,
+    InfoNCELoss,
+    SampledSoftmaxLoss,
+    TripletLoss,
+    get_loss_fn,
+)
 from nextrec.utils.tensor import to_tensor
 from nextrec.utils.device import configure_device
 from nextrec.utils.optimizer import get_optimizer, get_scheduler
@@ -71,13 +84,14 @@ class BaseModel(FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
-        device: str = "cpu",
-        early_stop_patience: int = 20,
-        session_id: str | None = None,
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
+        device: str = "cpu",
+        early_stop_patience: int = 20,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -91,16 +105,20 @@ class BaseModel(FeatureSet, nn.Module):
             dense_features: DenseFeature definitions.
            sparse_features: SparseFeature definitions.
             sequence_features: SequenceFeature definitions.
-            target: Target column name.
-            id_columns: Identifier column name, only need to specify if GAUC is required.
+            target: Target column name. e.g., 'label' or ['label1', 'label2'].
+            id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
             task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-
+
             embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
             dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
             embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
             dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
+
+            device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
             early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-            session_id: Session id for logging. If None, a default id with timestamps will be created.
+            session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
+            callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
+
             distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
             rank: Global rank (defaults to env RANK).
             world_size: Number of processes (defaults to env WORLD_SIZE).
@@ -151,6 +169,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
+        self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
 
     def register_regularization_weights(
         self,
@@ -162,8 +181,24 @@ class BaseModel(FeatureSet, nn.Module):
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
         embed_dict = getattr(embedding_layer, "embed_dict", None)
+        embedding_params: list[torch.Tensor] = []
         if embed_dict is not None:
-
+            embedding_params.extend(
+                embed.weight
+                for embed in embed_dict.values()
+                if hasattr(embed, "weight")
+            )
+        else:
+            weight = getattr(embedding_layer, "weight", None)
+            if isinstance(weight, torch.Tensor):
+                embedding_params.append(weight)
+
+        existing_embedding_ids = {id(param) for param in self.embedding_params}
+        for param in embedding_params:
+            if id(param) not in existing_embedding_ids:
+                self.embedding_params.append(param)
+                existing_embedding_ids.add(id(param))
+
         skip_types = (
             nn.BatchNorm1d,
             nn.BatchNorm2d,
@@ -172,6 +207,7 @@ class BaseModel(FeatureSet, nn.Module):
             nn.Dropout2d,
             nn.Dropout3d,
         )
+        existing_reg_ids = {id(param) for param in self.regularization_weights}
        for name, module in self.named_modules():
             if (
                 module is self
@@ -182,7 +218,9 @@ class BaseModel(FeatureSet, nn.Module):
             ):
                 continue
             if isinstance(module, nn.Linear):
-
+                if id(module.weight) not in existing_reg_ids:
+                    self.regularization_weights.append(module.weight)
+                    existing_reg_ids.add(id(module.weight))
 
     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
@@ -335,6 +373,7 @@ class BaseModel(FeatureSet, nn.Module):
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
         Configure the model for training.
@@ -346,6 +385,7 @@ class BaseModel(FeatureSet, nn.Module):
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+            callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         if loss_params is None:
             self.loss_params = {}
@@ -427,6 +467,11 @@ class BaseModel(FeatureSet, nn.Module):
             )
         self.loss_weights = weights
 
+        # Add callbacks from compile if provided
+        if callbacks:
+            for callback in callbacks:
+                self.callbacks.append(callback)
+
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError(
@@ -580,6 +625,53 @@ class BaseModel(FeatureSet, nn.Module):
                 task=self.task, metrics=metrics, target_names=self.target_columns
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+
+        # Setup default callbacks if none exist
+        if len(self.callbacks.callbacks) == 0:
+            if self.nums_task == 1:
+                monitor_metric = f"val_{self.metrics[0]}"
+            else:
+                monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
+
+            if self.early_stop_patience > 0:
+                self.callbacks.append(
+                    EarlyStopper(
+                        monitor=monitor_metric,
+                        patience=self.early_stop_patience,
+                        mode=self.best_metrics_mode,
+                        restore_best_weights=not self.distributed,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+            if self.is_main_process:
+                self.callbacks.append(
+                    CheckpointSaver(
+                        save_path=self.best_path,
+                        monitor=monitor_metric,
+                        mode=self.best_metrics_mode,
+                        save_best_only=True,
+                        verbose=1,
+                    )
+                )
+
+            if self.scheduler_fn is not None:
+                self.callbacks.append(
+                    LearningRateScheduler(
+                        scheduler=self.scheduler_fn,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+        self.callbacks.set_model(self)
+        self.callbacks.set_params(
+            {
+                "epochs": epochs,
+                "batch_size": batch_size,
+                "metrics": self.metrics,
+            }
+        )
+
         self.early_stopper = EarlyStopper(
             patience=self.early_stop_patience, mode=self.best_metrics_mode
         )
@@ -648,7 +740,17 @@ class BaseModel(FeatureSet, nn.Module):
             else:
                 train_loader = train_data
         else:
-
+            result = self.prepare_data_loader(
+                train_data,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                num_workers=num_workers,
+                return_dataset=True,
+            )
+            assert isinstance(
+                result, tuple
+            ), "Expected tuple from prepare_data_loader with return_dataset=True"
+            loader, dataset = result
         if (
             auto_distributed_sampler
             and self.distributed
@@ -729,8 +831,13 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info("")
             logging.info(colorize(f"Model device: {self.device}", bold=True))
 
+        self.callbacks.on_train_begin()
+
         for epoch in range(epochs):
             self.epoch_index = epoch
+
+            self.callbacks.on_epoch_begin(epoch)
+
             if is_streaming and self.is_main_process:
                 logging.info("")
                 logging.info(
@@ -740,10 +847,16 @@ class BaseModel(FeatureSet, nn.Module):
             # handle train result
             if (
                 self.distributed
+                and isinstance(train_loader, DataLoader)
                 and hasattr(train_loader, "sampler")
                 and isinstance(train_loader.sampler, DistributedSampler)
             ):
                 train_loader.sampler.set_epoch(epoch)
+            # Type guard: ensure train_loader is DataLoader for train_epoch
+            if not isinstance(train_loader, DataLoader):
+                raise TypeError(
+                    f"Expected DataLoader for training, got {type(train_loader)}"
+                )
             train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
             if isinstance(train_result, tuple):  # [avg_loss, metrics_dict]
                 train_loss, train_metrics = train_result
@@ -803,6 +916,9 @@ class BaseModel(FeatureSet, nn.Module):
                     train_log_payload, step=epoch + 1, split="train"
                 )
             if valid_loader is not None:
+                # Call on_validation_begin
+                self.callbacks.on_validation_begin()
+
                 # pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(
                     valid_loader,
@@ -849,17 +965,17 @@ class BaseModel(FeatureSet, nn.Module):
                             color="cyan",
                         )
                     )
+
+                # Call on_validation_end
+                self.callbacks.on_validation_end()
                 if val_metrics and self.training_logger:
                     self.training_logger.log_metrics(
                         val_metrics, step=epoch + 1, split="valid"
                     )
+
                 # Handle empty validation metrics
                 if not val_metrics:
                     if self.is_main_process:
-                        self.save_model(
-                            self.checkpoint_path, add_timestamp=False, verbose=False
-                        )
-                        self.best_checkpoint_path = self.checkpoint_path
                         logging.info(
                             colorize(
                                 "Warning: No validation metrics computed. Skipping validation for this epoch.",
@@ -867,81 +983,26 @@ class BaseModel(FeatureSet, nn.Module):
                             )
                         )
                     continue
-                if self.nums_task == 1:
-                    primary_metric_key = self.metrics[0]
-                else:
-                    primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
-                primary_metric = val_metrics.get(
-                    primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
-                )  # get primary metric value, default to first metric if not found
-
-                # In distributed mode, broadcast primary_metric to ensure all processes use the same value
-                if self.distributed and dist.is_available() and dist.is_initialized():
-                    metric_tensor = torch.tensor(
-                        [primary_metric], device=self.device, dtype=torch.float32
-                    )
-                    dist.broadcast(metric_tensor, src=0)
-                    primary_metric = float(metric_tensor.item())
-
-                improved = False
-                # early stopping check
-                if self.best_metrics_mode == "max":
-                    if primary_metric > self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
-                else:
-                    if primary_metric < self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
 
-                #
-
-
-
-                )
-
-                if improved:
-                    logging.info(
-                        colorize(
-                            f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
-                        )
-                    )
-                    self.save_model(
-                        self.best_path, add_timestamp=False, verbose=False
-                    )
-                    self.best_checkpoint_path = self.best_path
-                    self.early_stopper.trial_counter = 0
-                else:
-                    self.early_stopper.trial_counter += 1
-                    logging.info(
-                        colorize(
-                            f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
-                        )
-                    )
-                    if self.early_stopper.trial_counter >= self.early_stopper.patience:
-                        self.stop_training = True
-                        logging.info(
-                            colorize(
-                                f"Early stopping triggered after {epoch + 1} epochs",
-                                color="bright_red",
-                                bold=True,
-                            )
-                        )
-                else:
-                    # Non-main processes also update trial_counter to keep in sync
-                    if improved:
-                        self.early_stopper.trial_counter = 0
-                    else:
-                        self.early_stopper.trial_counter += 1
+                # Prepare epoch logs for callbacks
+                epoch_logs = {**train_log_payload}
+                if val_metrics:
+                    # Add val_ prefix to validation metrics
+                    for k, v in val_metrics.items():
+                        epoch_logs[f"val_{k}"] = v
             else:
+                # No validation data
+                epoch_logs = {**train_log_payload}
                 if self.is_main_process:
                     self.save_model(
                         self.checkpoint_path, add_timestamp=False, verbose=False
                     )
-                    self.
-                    self.best_checkpoint_path = self.best_path
+                    self.best_checkpoint_path = self.checkpoint_path
 
-            #
+            # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
+            self.callbacks.on_epoch_end(epoch, epoch_logs)
+
+            # Broadcast stop_training flag to all processes
             if self.distributed and dist.is_available() and dist.is_initialized():
                 stop_tensor = torch.tensor(
                     [int(self.stop_training)], device=self.device
@@ -951,14 +1012,9 @@ class BaseModel(FeatureSet, nn.Module):
 
             if self.stop_training:
                 break
-
-
-
-            ):
-                if valid_loader is not None:
-                    self.scheduler_fn.step(primary_metric)
-                else:
-                    self.scheduler_fn.step()
+        # Call on_train_end for all callbacks
+        self.callbacks.on_train_end()
+
         if self.distributed and dist.is_available() and dist.is_initialized():
             dist.barrier()  # dist.barrier() will wait for all processes, like async all_reduce()
         if self.is_main_process:
@@ -970,9 +1026,17 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info(
                 colorize(f"Load best model from: {self.best_checkpoint_path}")
             )
-            self.
-            self.
-
+            if os.path.exists(self.best_checkpoint_path):
+                self.load_model(
+                    self.best_checkpoint_path, map_location=self.device, verbose=False
+                )
+            elif self.is_main_process:
+                logging.info(
+                    colorize(
+                        f"Warning: Best checkpoint not found at {self.best_checkpoint_path}, skip loading best model.",
+                        color="yellow",
+                    )
+                )
         if self.training_logger:
             self.training_logger.close()
         return self
@@ -1847,7 +1911,9 @@ class BaseModel(FeatureSet, nn.Module):
 class BaseMatchModel(BaseModel):
     """
     Base class for match (retrieval/recall) models
-
+
+    - Pointwise: predicts a user-item match score/probability using labels (default target: 'label')
+    - Pairwise/Listwise: trains with in-batch negatives; labels can be omitted by setting target=None
     """
 
     @property
@@ -1887,6 +1953,16 @@ class BaseMatchModel(BaseModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
+        target: list[str] | str | None = "label",
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
+        distributed: bool = False,
+        rank: int | None = None,
+        world_size: int | None = None,
+        local_rank: int | None = None,
+        ddp_find_unused_parameters: bool = False,
         **kwargs,
     ):
 
@@ -1911,14 +1987,22 @@ class BaseMatchModel(BaseModel):
             dense_features=all_dense_features,
             sparse_features=all_sparse_features,
             sequence_features=all_sequence_features,
-            target=
-
+            target=target,
+            id_columns=id_columns,
+            task=task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
+            session_id=session_id,
+            callbacks=callbacks,
+            distributed=distributed,
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            ddp_find_unused_parameters=ddp_find_unused_parameters,
             **kwargs,
         )
 
@@ -1989,73 +2073,76 @@ class BaseMatchModel(BaseModel):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
-
-
+        Configure the match model for training.
+
+        This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
                 f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )
-        # Call parent compile with match-specific logic
-        optimizer_params = optimizer_params or {}
-
-        self.optimizer_name = (
-            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
-        )
-        self.optimizer_params = optimizer_params
-        if isinstance(scheduler, str):
-            self.scheduler_name = scheduler
-        elif scheduler is not None:
-            # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.scheduler_name = getattr(
-                scheduler,
-                "__name__",
-                getattr(scheduler.__class__, "__name__", str(scheduler)),
-            )
-        else:
-            self.scheduler_name = None
-        self.scheduler_params = scheduler_params or {}
-        self.loss_config = loss
-        self.loss_params = loss_params or {}
 
-
-            optimizer=optimizer, params=self.parameters(), **optimizer_params
-        )
-        # Set loss function based on training mode
-        default_losses = {
+        default_loss_by_mode: dict[str, str] = {
             "pointwise": "bce",
             "pairwise": "bpr",
             "listwise": "sampled_softmax",
         }
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+        if effective_loss is None:
+            effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, (str,)):
+            if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
+                "bce",
+                "binary_crossentropy",
+            }:
+                effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, list):
+            if not effective_loss:
+                effective_loss = [default_loss_by_mode[self.training_mode]]
+            else:
+                first = effective_loss[0]
+                if (
+                    self.training_mode in {"pairwise", "listwise"}
+                    and isinstance(first, str)
+                    and first in {"bce", "binary_crossentropy"}
+                ):
+                    effective_loss = [
+                        default_loss_by_mode[self.training_mode],
+                        *effective_loss[1:],
+                    ]
+        return super().compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=effective_loss,
+            loss_params=loss_params,
+            loss_weights=loss_weights,
+            callbacks=callbacks,
        )
 
+    def inbatch_logits(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
+        if self.similarity_metric == "dot":
+            logits = torch.matmul(user_emb, item_emb.t())
+        elif self.similarity_metric == "cosine":
+            user_norm = F.normalize(user_emb, p=2, dim=-1)
+            item_norm = F.normalize(item_emb, p=2, dim=-1)
+            logits = torch.matmul(user_norm, item_norm.t())
+        elif self.similarity_metric == "euclidean":
+            user_sq = (user_emb**2).sum(dim=1, keepdim=True)  # [B, 1]
+            item_sq = (item_emb**2).sum(dim=1, keepdim=True).t()  # [1, B]
+            logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
+        else:
+            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
+        return logits / self.temperature
+
     def compute_similarity(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
@@ -2125,9 +2212,7 @@ class BaseMatchModel(BaseModel):
 
     def compute_loss(self, y_pred, y_true):
         if self.training_mode == "pointwise":
-
-            return torch.tensor(0.0, device=self.device)
-            return self.loss_fn[0](y_pred, y_true)
+            return super().compute_loss(y_pred, y_true)
 
         # pairwise / listwise using inbatch neg
         elif self.training_mode in ["pairwise", "listwise"]:
@@ -2136,14 +2221,39 @@ class BaseMatchModel(BaseModel):
                     "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
                 )
             user_emb, item_emb = y_pred  # [B, D], [B, D]
-
-
-
-
-
-
-
-
+            batch_size = user_emb.size(0)
+            if batch_size < 2:
+                return torch.tensor(0.0, device=user_emb.device)
+
+            logits = self.inbatch_logits(user_emb, item_emb)  # [B, B]
+
+            eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
+            pos_logits = logits.diag()  # [B]
+            neg_logits = logits.masked_select(~eye).view(
+                batch_size, batch_size - 1
+            )  # [B, B-1]
+
+            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+            if isinstance(loss_fn, SampledSoftmaxLoss):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, (BPRLoss, HingeLoss)):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, TripletLoss):
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            elif isinstance(loss_fn, InfoNCELoss) and self.similarity_metric == "dot":
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            else:
+                targets = torch.arange(batch_size, device=logits.device)
+                loss = F.cross_entropy(logits, targets)
+
+            if self.loss_weights is not None:
+                loss *= float(self.loss_weights[0])
             return loss
         else:
            raise ValueError(f"Unknown training mode: {self.training_mode}")
@@ -2154,17 +2264,23 @@ class BaseMatchModel(BaseModel):
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
-
-
-
-
-
-
-
-
-
-
-
+        tensors = build_tensors_from_data(
+            data=data,
+            raw_data=data,
+            features=features,
+            target_columns=[],
+            id_columns=[],
+        )
+        if tensors is None:
+            raise ValueError(
+                "[BaseMatchModel-prepare_feature_data Error] No data available to create DataLoader."
+            )
+        dataset = TensorDictDataset(tensors)
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=collate_fn,
        )
 
     def encode_user(