nextrec 0.4.5__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +399 -21
- nextrec/basic/features.py +4 -0
- nextrec/basic/layers.py +103 -24
- nextrec/basic/metrics.py +71 -1
- nextrec/basic/model.py +285 -186
- nextrec/data/data_processing.py +1 -3
- nextrec/loss/loss_utils.py +73 -4
- nextrec/models/generative/__init__.py +16 -0
- nextrec/models/generative/hstu.py +110 -57
- nextrec/models/generative/rqvae.py +826 -0
- nextrec/models/match/dssm.py +5 -4
- nextrec/models/match/dssm_v2.py +4 -3
- nextrec/models/match/mind.py +5 -4
- nextrec/models/match/sdm.py +5 -4
- nextrec/models/match/youtube_dnn.py +5 -4
- nextrec/models/ranking/masknet.py +1 -1
- nextrec/utils/config.py +38 -1
- nextrec/utils/embedding.py +28 -0
- nextrec/utils/initializer.py +4 -4
- nextrec/utils/synthetic_data.py +19 -0
- nextrec-0.4.7.dist-info/METADATA +376 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/RECORD +26 -25
- nextrec-0.4.5.dist-info/METADATA +0 -357
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/WHEEL +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.5.dist-info → nextrec-0.4.7.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 18/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -25,7 +25,13 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from nextrec.basic.callback import
+from nextrec.basic.callback import (
+    EarlyStopper,
+    CallbackList,
+    Callback,
+    CheckpointSaver,
+    LearningRateScheduler,
+)
 from nextrec.basic.features import (
     DenseFeature,
     SparseFeature,
@@ -42,7 +48,15 @@ from nextrec.data.dataloader import build_tensors_from_data
 from nextrec.data.batch_utils import collate_fn, batch_to_dict
 from nextrec.data.data_processing import get_column_data, get_user_ids
 
-from nextrec.loss import
+from nextrec.loss import (
+    BPRLoss,
+    HingeLoss,
+    InfoNCELoss,
+    SampledSoftmaxLoss,
+    TripletLoss,
+    get_loss_fn,
+    get_loss_kwargs,
+)
 from nextrec.utils.tensor import to_tensor
 from nextrec.utils.device import configure_device
 from nextrec.utils.optimizer import get_optimizer, get_scheduler
@@ -71,13 +85,14 @@ class BaseModel(FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
-        device: str = "cpu",
-        early_stop_patience: int = 20,
-        session_id: str | None = None,
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
+        device: str = "cpu",
+        early_stop_patience: int = 20,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -91,16 +106,20 @@ class BaseModel(FeatureSet, nn.Module):
             dense_features: DenseFeature definitions.
             sparse_features: SparseFeature definitions.
             sequence_features: SequenceFeature definitions.
-            target: Target column name.
-            id_columns: Identifier column name, only need to specify if GAUC is required.
+            target: Target column name. e.g., 'label' or ['label1', 'label2'].
+            id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
             task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-
+
             embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
             dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
             embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
             dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
+
+            device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
             early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-            session_id: Session id for logging. If None, a default id with timestamps will be created.
+            session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
+            callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
+
             distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
             rank: Global rank (defaults to env RANK).
             world_size: Number of processes (defaults to env WORLD_SIZE).
@@ -126,11 +145,9 @@ class BaseModel(FeatureSet, nn.Module):
         self.session = create_session(session_id)
         self.session_path = self.session.root  # pwd/session_id, path for this session
         self.checkpoint_path = os.path.join(
-            self.session_path, self.model_name + "_checkpoint.
-        )  # example: pwd/session_id/DeepFM_checkpoint.
-        self.best_path = os.path.join(
-            self.session_path, self.model_name + "_best.model"
-        )
+            self.session_path, self.model_name + "_checkpoint.pt"
+        )  # example: pwd/session_id/DeepFM_checkpoint.pt
+        self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
         self.features_config_path = os.path.join(
             self.session_path, "features_config.pkl"
         )
@@ -153,6 +170,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
+        self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
 
     def register_regularization_weights(
         self,
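With this change, callbacks can be supplied at construction time and are wrapped in a CallbackList. A usage sketch follows; the concrete model class (a DeepFM-style ranking model) and the pre-built feature lists are assumed for illustration and are not part of this diff, while the keyword arguments mirror the updated __init__ signature:

```python
# Illustrative sketch: `DeepFM`, `dense_features` and `sparse_features` are assumed
# to exist already; only the keyword arguments below come from the signature in this diff.
from nextrec.basic.callback import CheckpointSaver, EarlyStopper

model = DeepFM(
    dense_features=dense_features,
    sparse_features=sparse_features,
    target="label",
    task="binary",
    device="cuda:0",
    early_stop_patience=20,
    session_id="session_tutorial",
    callbacks=[
        EarlyStopper(monitor="val_auc", patience=5, mode="max"),
        CheckpointSaver(save_path="DeepFM_best.pt", monitor="val_auc", mode="max", save_best_only=True),
    ],
)
```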
@@ -164,8 +182,22 @@ class BaseModel(FeatureSet, nn.Module):
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
         embed_dict = getattr(embedding_layer, "embed_dict", None)
+        embedding_params: list[torch.Tensor] = []
         if embed_dict is not None:
-
+            embedding_params.extend(
+                embed.weight for embed in embed_dict.values() if hasattr(embed, "weight")
+            )
+        else:
+            weight = getattr(embedding_layer, "weight", None)
+            if isinstance(weight, torch.Tensor):
+                embedding_params.append(weight)
+
+        existing_embedding_ids = {id(param) for param in self.embedding_params}
+        for param in embedding_params:
+            if id(param) not in existing_embedding_ids:
+                self.embedding_params.append(param)
+                existing_embedding_ids.add(id(param))
+
         skip_types = (
             nn.BatchNorm1d,
             nn.BatchNorm2d,
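The rewritten registration above de-duplicates tensors by object identity, so calling register_regularization_weights repeatedly cannot add the same embedding weight twice. A minimal, self-contained sketch of that pattern in plain PyTorch (names invented for illustration):

```python
import torch
import torch.nn as nn

# Toy registry that mimics the id()-based de-duplication used above:
# a parameter is appended only if its identity is not already tracked.
embedding_params: list[torch.Tensor] = []

def register(params: list[torch.Tensor]) -> None:
    existing = {id(p) for p in embedding_params}
    for p in params:
        if id(p) not in existing:
            embedding_params.append(p)
            existing.add(id(p))

emb = nn.Embedding(10, 4)
register([emb.weight])
register([emb.weight])  # second call is a no-op for the same tensor
assert len(embedding_params) == 1
```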
@@ -174,6 +206,7 @@ class BaseModel(FeatureSet, nn.Module):
             nn.Dropout2d,
             nn.Dropout3d,
         )
+        existing_reg_ids = {id(param) for param in self.regularization_weights}
         for name, module in self.named_modules():
             if (
                 module is self
@@ -184,7 +217,9 @@ class BaseModel(FeatureSet, nn.Module):
             ):
                 continue
             if isinstance(module, nn.Linear):
-
+                if id(module.weight) not in existing_reg_ids:
+                    self.regularization_weights.append(module.weight)
+                    existing_reg_ids.add(id(module.weight))
 
     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
@@ -337,6 +372,7 @@ class BaseModel(FeatureSet, nn.Module):
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
         Configure the model for training.
@@ -348,6 +384,7 @@ class BaseModel(FeatureSet, nn.Module):
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+            callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         if loss_params is None:
             self.loss_params = {}
@@ -429,6 +466,11 @@ class BaseModel(FeatureSet, nn.Module):
             )
         self.loss_weights = weights
 
+        # Add callbacks from compile if provided
+        if callbacks:
+            for callback in callbacks:
+                self.callbacks.append(callback)
+
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError(
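compile() now accepts a callbacks argument and appends each entry to the list created in __init__. A usage sketch, assuming `model` is an instance of a BaseModel subclass; the argument values are illustrative, the parameter names come from the signature shown in this diff:

```python
from nextrec.basic.callback import EarlyStopper

# `model` is assumed to be an already constructed BaseModel subclass instance.
model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3},
    loss="bce",
    loss_weights=1.0,
    # Appended to the callbacks passed at construction time, if any.
    callbacks=[EarlyStopper(monitor="val_auc", patience=3, mode="max")],
)
```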
@@ -582,6 +624,53 @@ class BaseModel(FeatureSet, nn.Module):
                 task=self.task, metrics=metrics, target_names=self.target_columns
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+
+        # Setup default callbacks if none exist
+        if len(self.callbacks.callbacks) == 0:
+            if self.nums_task == 1:
+                monitor_metric = f"val_{self.metrics[0]}"
+            else:
+                monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
+
+            if self.early_stop_patience > 0:
+                self.callbacks.append(
+                    EarlyStopper(
+                        monitor=monitor_metric,
+                        patience=self.early_stop_patience,
+                        mode=self.best_metrics_mode,
+                        restore_best_weights=not self.distributed,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+            if self.is_main_process:
+                self.callbacks.append(
+                    CheckpointSaver(
+                        save_path=self.best_path,
+                        monitor=monitor_metric,
+                        mode=self.best_metrics_mode,
+                        save_best_only=True,
+                        verbose=1,
+                    )
+                )
+
+            if self.scheduler_fn is not None:
+                self.callbacks.append(
+                    LearningRateScheduler(
+                        scheduler=self.scheduler_fn,
+                        verbose=1 if self.is_main_process else 0,
+                    )
+                )
+
+        self.callbacks.set_model(self)
+        self.callbacks.set_params(
+            {
+                "epochs": epochs,
+                "batch_size": batch_size,
+                "metrics": self.metrics,
+            }
+        )
+
         self.early_stopper = EarlyStopper(
             patience=self.early_stop_patience, mode=self.best_metrics_mode
         )
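The default-callback setup above derives the monitored key from the first metric, adding the first target column for multi-task models. A small runnable sketch of that naming rule in isolation (the helper function is hypothetical, written only to mirror the convention):

```python
def default_monitor_metric(metrics: list[str], target_columns: list[str]) -> str:
    """Mirror the monitor-key convention used when default callbacks are created."""
    if len(target_columns) <= 1:  # single task -> e.g. "val_auc"
        return f"val_{metrics[0]}"
    return f"val_{metrics[0]}_{target_columns[0]}"  # multi-task -> e.g. "val_auc_click"

assert default_monitor_metric(["auc", "logloss"], ["label"]) == "val_auc"
assert default_monitor_metric(["auc"], ["click", "purchase"]) == "val_auc_click"
```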
@@ -650,7 +739,9 @@ class BaseModel(FeatureSet, nn.Module):
             else:
                 train_loader = train_data
         else:
-
+            result = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)
+            assert isinstance(result, tuple), "Expected tuple from prepare_data_loader with return_dataset=True"
+            loader, dataset = result
             if (
                 auto_distributed_sampler
                 and self.distributed
@@ -731,8 +822,13 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info("")
             logging.info(colorize(f"Model device: {self.device}", bold=True))
 
+        self.callbacks.on_train_begin()
+
         for epoch in range(epochs):
             self.epoch_index = epoch
+
+            self.callbacks.on_epoch_begin(epoch)
+
             if is_streaming and self.is_main_process:
                 logging.info("")
                 logging.info(
@@ -742,10 +838,14 @@ class BaseModel(FeatureSet, nn.Module):
             # handle train result
             if (
                 self.distributed
+                and isinstance(train_loader, DataLoader)
                 and hasattr(train_loader, "sampler")
                 and isinstance(train_loader.sampler, DistributedSampler)
             ):
                 train_loader.sampler.set_epoch(epoch)
+            # Type guard: ensure train_loader is DataLoader for train_epoch
+            if not isinstance(train_loader, DataLoader):
+                raise TypeError(f"Expected DataLoader for training, got {type(train_loader)}")
             train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
             if isinstance(train_result, tuple):  # [avg_loss, metrics_dict]
                 train_loss, train_metrics = train_result
@@ -805,6 +905,9 @@ class BaseModel(FeatureSet, nn.Module):
                     train_log_payload, step=epoch + 1, split="train"
                 )
             if valid_loader is not None:
+                # Call on_validation_begin
+                self.callbacks.on_validation_begin()
+
                 # pass user_ids only if needed for GAUC metric
                 val_metrics = self.evaluate(
                     valid_loader,
@@ -851,17 +954,17 @@ class BaseModel(FeatureSet, nn.Module):
                             color="cyan",
                         )
                     )
+
+                # Call on_validation_end
+                self.callbacks.on_validation_end()
                 if val_metrics and self.training_logger:
                     self.training_logger.log_metrics(
                         val_metrics, step=epoch + 1, split="valid"
                     )
+
                 # Handle empty validation metrics
                 if not val_metrics:
                     if self.is_main_process:
-                        self.save_model(
-                            self.checkpoint_path, add_timestamp=False, verbose=False
-                        )
-                        self.best_checkpoint_path = self.checkpoint_path
                         logging.info(
                             colorize(
                                 "Warning: No validation metrics computed. Skipping validation for this epoch.",
@@ -869,81 +972,26 @@ class BaseModel(FeatureSet, nn.Module):
                             )
                         )
                     continue
-                if self.nums_task == 1:
-                    primary_metric_key = self.metrics[0]
-                else:
-                    primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
-                primary_metric = val_metrics.get(
-                    primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
-                )  # get primary metric value, default to first metric if not found
-
-                # In distributed mode, broadcast primary_metric to ensure all processes use the same value
-                if self.distributed and dist.is_available() and dist.is_initialized():
-                    metric_tensor = torch.tensor(
-                        [primary_metric], device=self.device, dtype=torch.float32
-                    )
-                    dist.broadcast(metric_tensor, src=0)
-                    primary_metric = float(metric_tensor.item())
-
-                improved = False
-                # early stopping check
-                if self.best_metrics_mode == "max":
-                    if primary_metric > self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
-                else:
-                    if primary_metric < self.best_metric:
-                        self.best_metric = primary_metric
-                        improved = True
 
-                #
-
-
-
-                )
-
-                    if improved:
-                        logging.info(
-                            colorize(
-                                f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
-                            )
-                        )
-                        self.save_model(
-                            self.best_path, add_timestamp=False, verbose=False
-                        )
-                        self.best_checkpoint_path = self.best_path
-                        self.early_stopper.trial_counter = 0
-                    else:
-                        self.early_stopper.trial_counter += 1
-                        logging.info(
-                            colorize(
-                                f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
-                            )
-                        )
-                        if self.early_stopper.trial_counter >= self.early_stopper.patience:
-                            self.stop_training = True
-                            logging.info(
-                                colorize(
-                                    f"Early stopping triggered after {epoch + 1} epochs",
-                                    color="bright_red",
-                                    bold=True,
-                                )
-                            )
-                else:
-                    # Non-main processes also update trial_counter to keep in sync
-                    if improved:
-                        self.early_stopper.trial_counter = 0
-                    else:
-                        self.early_stopper.trial_counter += 1
+                # Prepare epoch logs for callbacks
+                epoch_logs = {**train_log_payload}
+                if val_metrics:
+                    # Add val_ prefix to validation metrics
+                    for k, v in val_metrics.items():
+                        epoch_logs[f"val_{k}"] = v
             else:
+                # No validation data
+                epoch_logs = {**train_log_payload}
                 if self.is_main_process:
                     self.save_model(
                         self.checkpoint_path, add_timestamp=False, verbose=False
                     )
-                    self.
-                    self.best_checkpoint_path = self.best_path
+                    self.best_checkpoint_path = self.checkpoint_path
 
-            #
+            # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
+            self.callbacks.on_epoch_end(epoch, epoch_logs)
+
+            # Broadcast stop_training flag to all processes
             if self.distributed and dist.is_available() and dist.is_initialized():
                 stop_tensor = torch.tensor(
                     [int(self.stop_training)], device=self.device
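Early stopping, best-checkpoint saving and LR scheduling are now driven by callbacks instead of inline logic: the loop builds an epoch_logs dict from the training payload plus val_-prefixed validation metrics and hands it to on_epoch_end. A self-contained analogue of that hook sequence (a plain Python stand-in, not the actual nextrec Callback base class; the numbers are dummy values):

```python
class LoggingCallback:
    """Stand-in with the same hook names the training loop above invokes."""

    def on_train_begin(self): print("train begin")
    def on_epoch_begin(self, epoch): print(f"epoch {epoch} begin")
    def on_validation_begin(self): pass
    def on_validation_end(self): pass
    def on_epoch_end(self, epoch, logs): print(f"epoch {epoch}: {logs}")
    def on_train_end(self): print("train end")

cb = LoggingCallback()
cb.on_train_begin()
for epoch in range(2):
    cb.on_epoch_begin(epoch)
    train_log_payload = {"loss": 0.35 - 0.01 * epoch}
    val_metrics = {"auc": 0.71 + 0.01 * epoch}
    epoch_logs = {**train_log_payload}
    for k, v in val_metrics.items():
        epoch_logs[f"val_{k}"] = v  # same "val_" prefixing as in the fit loop above
    cb.on_epoch_end(epoch, epoch_logs)
cb.on_train_end()
```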
@@ -953,14 +1001,9 @@ class BaseModel(FeatureSet, nn.Module):
 
             if self.stop_training:
                 break
-
-
-
-        ):
-            if valid_loader is not None:
-                self.scheduler_fn.step(primary_metric)
-            else:
-                self.scheduler_fn.step()
+        # Call on_train_end for all callbacks
+        self.callbacks.on_train_end()
+
         if self.distributed and dist.is_available() and dist.is_initialized():
             dist.barrier()  # dist.barrier() will wait for all processes, like async all_reduce()
         if self.is_main_process:
@@ -972,9 +1015,17 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info(
                 colorize(f"Load best model from: {self.best_checkpoint_path}")
             )
-        self.
-        self.
-
+        if os.path.exists(self.best_checkpoint_path):
+            self.load_model(
+                self.best_checkpoint_path, map_location=self.device, verbose=False
+            )
+        elif self.is_main_process:
+            logging.info(
+                colorize(
+                    f"Warning: Best checkpoint not found at {self.best_checkpoint_path}, skip loading best model.",
+                    color="yellow",
+                )
+            )
         if self.training_logger:
             self.training_logger.close()
         return self
@@ -1563,7 +1614,7 @@ class BaseModel(FeatureSet, nn.Module):
             path=save_path,
             default_dir=self.session_path,
             default_name=self.model_name,
-            suffix=".
+            suffix=".pt",
             add_timestamp=add_timestamp,
         )
         model_path = Path(target_path)
@@ -1603,16 +1654,16 @@ class BaseModel(FeatureSet, nn.Module):
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
-            model_files = sorted(base_path.glob("*.
+            model_files = sorted(base_path.glob("*.pt"))
             if not model_files:
                 raise FileNotFoundError(
-                    f"[BaseModel-load-model Error] No *.
+                    f"[BaseModel-load-model Error] No *.pt file found in directory: {base_path}"
                 )
             model_path = model_files[-1]
             config_dir = base_path
         else:
             model_path = (
-                base_path.with_suffix(".
+                base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
             )
             config_dir = model_path.parent
         if not model_path.exists():
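Checkpoints now use a .pt suffix, and loading from a directory picks the last *.pt file in sorted order. A runnable pathlib sketch of that selection logic (temporary files created only for illustration):

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    base_path = Path(tmp)
    for name in ["DeepFM_best.pt", "DeepFM_checkpoint.pt"]:
        (base_path / name).touch()

    model_files = sorted(base_path.glob("*.pt"))
    if not model_files:
        raise FileNotFoundError(f"No *.pt file found in directory: {base_path}")
    model_path = model_files[-1]  # lexicographically last .pt file is chosen
    print(model_path.name)        # -> DeepFM_checkpoint.pt
```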
@@ -1665,21 +1716,21 @@ class BaseModel(FeatureSet, nn.Module):
     ) -> "BaseModel":
         """
         Load a model from a checkpoint path. The checkpoint path should contain:
-        a .
+        a .pt file and a features_config.pkl file.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
         if base_path.is_dir():
-            model_candidates = sorted(base_path.glob("*.
+            model_candidates = sorted(base_path.glob("*.pt"))
             if not model_candidates:
                 raise FileNotFoundError(
-                    f"[BaseModel-from-checkpoint Error] No *.
+                    f"[BaseModel-from-checkpoint Error] No *.pt file found under: {base_path}"
                 )
             model_file = model_candidates[-1]
             config_dir = base_path
         else:
             model_file = (
-                base_path.with_suffix(".
+                base_path.with_suffix(".pt") if base_path.suffix == "" else base_path
             )
             config_dir = model_file.parent
         features_config_path = config_dir / "features_config.pkl"
@@ -1849,7 +1900,9 @@ class BaseModel(FeatureSet, nn.Module):
 class BaseMatchModel(BaseModel):
     """
     Base class for match (retrieval/recall) models
-
+
+    - Pointwise: predicts a user-item match score/probability using labels (default target: 'label')
+    - Pairwise/Listwise: trains with in-batch negatives; labels can be omitted by setting target=None
     """
 
     @property
@@ -1889,6 +1942,16 @@ class BaseMatchModel(BaseModel):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         early_stop_patience: int = 20,
+        target: list[str] | str | None = "label",
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        session_id: str | None = None,
+        callbacks: list[Callback] | None = None,
+        distributed: bool = False,
+        rank: int | None = None,
+        world_size: int | None = None,
+        local_rank: int | None = None,
+        ddp_find_unused_parameters: bool = False,
         **kwargs,
     ):
 
@@ -1913,14 +1976,22 @@ class BaseMatchModel(BaseModel):
             dense_features=all_dense_features,
             sparse_features=all_sparse_features,
             sequence_features=all_sequence_features,
-            target=
-
+            target=target,
+            id_columns=id_columns,
+            task=task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
+            session_id=session_id,
+            callbacks=callbacks,
+            distributed=distributed,
+            rank=rank,
+            world_size=world_size,
+            local_rank=local_rank,
+            ddp_find_unused_parameters=ddp_find_unused_parameters,
             **kwargs,
         )
 
@@ -1991,73 +2062,74 @@ class BaseMatchModel(BaseModel):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+        callbacks: list[Callback] | None = None,
     ):
         """
-
-
+        Configure the match model for training.
+
+        This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
                 f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )
-        # Call parent compile with match-specific logic
-        optimizer_params = optimizer_params or {}
 
-
-            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
-        )
-        self.optimizer_params = optimizer_params
-        if isinstance(scheduler, str):
-            self.scheduler_name = scheduler
-        elif scheduler is not None:
-            # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.scheduler_name = getattr(
-                scheduler,
-                "__name__",
-                getattr(scheduler.__class__, "__name__", str(scheduler)),
-            )
-        else:
-            self.scheduler_name = None
-        self.scheduler_params = scheduler_params or {}
-        self.loss_config = loss
-        self.loss_params = loss_params or {}
-
-        self.optimizer_fn = get_optimizer(
-            optimizer=optimizer, params=self.parameters(), **optimizer_params
-        )
-        # Set loss function based on training mode
-        default_losses = {
+        default_loss_by_mode: dict[str, str] = {
             "pointwise": "bce",
             "pairwise": "bpr",
             "listwise": "sampled_softmax",
         }
 
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
+        if effective_loss is None:
+            effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, (str,)):
+            if (
+                self.training_mode in {"pairwise", "listwise"}
+                and effective_loss in {"bce", "binary_crossentropy"}
+            ):
+                effective_loss = default_loss_by_mode[self.training_mode]
+        elif isinstance(effective_loss, list):
+            if not effective_loss:
+                effective_loss = [default_loss_by_mode[self.training_mode]]
+            else:
+                first = effective_loss[0]
+                if (
+                    self.training_mode in {"pairwise", "listwise"}
+                    and isinstance(first, str)
+                    and first in {"bce", "binary_crossentropy"}
+                ):
+                    effective_loss = [
+                        default_loss_by_mode[self.training_mode],
+                        *effective_loss[1:],
+                    ]
+        return super().compile(
+            optimizer=optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=scheduler,
+            scheduler_params=scheduler_params,
+            loss=effective_loss,
+            loss_params=loss_params,
+            loss_weights=loss_weights,
+            callbacks=callbacks,
         )
 
+    def inbatch_logits(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+        if self.similarity_metric == "dot":
+            logits = torch.matmul(user_emb, item_emb.t())
+        elif self.similarity_metric == "cosine":
+            user_norm = F.normalize(user_emb, p=2, dim=-1)
+            item_norm = F.normalize(item_emb, p=2, dim=-1)
+            logits = torch.matmul(user_norm, item_norm.t())
+        elif self.similarity_metric == "euclidean":
+            user_sq = (user_emb**2).sum(dim=1, keepdim=True)  # [B, 1]
+            item_sq = (item_emb**2).sum(dim=1, keepdim=True).t()  # [1, B]
+            logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
+        else:
+            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
+        return logits / self.temperature
+
     def compute_similarity(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
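The new inbatch_logits helper scores every user in the batch against every item under the configured similarity metric and scales by a temperature. A standalone PyTorch sketch of the same computation, with the metric and temperature as plain arguments instead of model attributes:

```python
import torch
import torch.nn.functional as F

def inbatch_logits(user_emb: torch.Tensor, item_emb: torch.Tensor,
                   metric: str = "dot", temperature: float = 1.0) -> torch.Tensor:
    """Return a [B, B] score matrix; the diagonal entries are the positive pairs."""
    if metric == "dot":
        logits = user_emb @ item_emb.t()
    elif metric == "cosine":
        logits = F.normalize(user_emb, dim=-1) @ F.normalize(item_emb, dim=-1).t()
    elif metric == "euclidean":
        user_sq = (user_emb ** 2).sum(dim=1, keepdim=True)      # [B, 1]
        item_sq = (item_emb ** 2).sum(dim=1, keepdim=True).t()  # [1, B]
        logits = -(user_sq + item_sq - 2.0 * user_emb @ item_emb.t())  # negative squared distance
    else:
        raise ValueError(f"Unknown similarity metric: {metric}")
    return logits / temperature

user_emb, item_emb = torch.randn(4, 8), torch.randn(4, 8)
print(inbatch_logits(user_emb, item_emb, metric="cosine", temperature=0.1).shape)  # torch.Size([4, 4])
```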
@@ -2127,9 +2199,7 @@ class BaseMatchModel(BaseModel):
 
     def compute_loss(self, y_pred, y_true):
         if self.training_mode == "pointwise":
-
-                return torch.tensor(0.0, device=self.device)
-            return self.loss_fn[0](y_pred, y_true)
+            return super().compute_loss(y_pred, y_true)
 
         # pairwise / listwise using inbatch neg
         elif self.training_mode in ["pairwise", "listwise"]:
@@ -2138,14 +2208,37 @@ class BaseMatchModel(BaseModel):
                     "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
                 )
             user_emb, item_emb = y_pred  # [B, D], [B, D]
-
-
-
-
-
-
-
-
+            batch_size = user_emb.size(0)
+            if batch_size < 2:
+                return torch.tensor(0.0, device=user_emb.device)
+
+            logits = self.inbatch_logits(user_emb, item_emb)  # [B, B]
+
+            eye = torch.eye(batch_size, device=logits.device, dtype=torch.bool)
+            pos_logits = logits.diag()  # [B]
+            neg_logits = logits.masked_select(~eye).view(batch_size, batch_size - 1)  # [B, B-1]
+
+            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+            if isinstance(loss_fn, SampledSoftmaxLoss):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, (BPRLoss, HingeLoss)):
+                loss = loss_fn(pos_logits, neg_logits)
+            elif isinstance(loss_fn, TripletLoss):
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            elif isinstance(loss_fn, InfoNCELoss) and self.similarity_metric == "dot":
+                neg_emb = item_emb.masked_select(~eye.unsqueeze(-1)).view(
+                    batch_size, batch_size - 1, item_emb.size(-1)
+                )
+                loss = loss_fn(user_emb, item_emb, neg_emb)
+            else:
+                targets = torch.arange(batch_size, device=logits.device)
+                loss = F.cross_entropy(logits, targets)
+
+            if self.loss_weights is not None:
+                loss *= float(self.loss_weights[0])
             return loss
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
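For pairwise and listwise training the positives sit on the diagonal of the in-batch logit matrix and the off-diagonal entries act as negatives; when no specialised loss matches, the code falls back to cross-entropy with arange targets. A runnable sketch of that split and fallback:

```python
import torch
import torch.nn.functional as F

batch_size, dim = 4, 8
user_emb, item_emb = torch.randn(batch_size, dim), torch.randn(batch_size, dim)

logits = user_emb @ item_emb.t()                       # [B, B], row i scores user i against all items
eye = torch.eye(batch_size, dtype=torch.bool)

pos_logits = logits.diag()                             # [B] matched user-item pairs
neg_logits = logits.masked_select(~eye).view(batch_size, batch_size - 1)  # [B, B-1] in-batch negatives

# Cross-entropy fallback: the target for row i is column i (its own item).
targets = torch.arange(batch_size)
loss = F.cross_entropy(logits, targets)
print(pos_logits.shape, neg_logits.shape, float(loss))
```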
@@ -2156,17 +2249,23 @@ class BaseMatchModel(BaseModel):
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
-
-
-
-
-
-
-
-
-
-
-
+        tensors = build_tensors_from_data(
+            data=data,
+            raw_data=data,
+            features=features,
+            target_columns=[],
+            id_columns=[],
+        )
+        if tensors is None:
+            raise ValueError(
+                "[BaseMatchModel-prepare_feature_data Error] No data available to create DataLoader."
+            )
+        dataset = TensorDictDataset(tensors)
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            collate_fn=collate_fn,
         )
 
     def encode_user(