nextrec 0.4.16__py3-none-any.whl → 0.4.18__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- nextrec/__version__.py +1 -1
- nextrec/basic/heads.py +99 -0
- nextrec/basic/loggers.py +5 -5
- nextrec/basic/model.py +217 -88
- nextrec/cli.py +1 -1
- nextrec/data/dataloader.py +93 -95
- nextrec/data/preprocessor.py +108 -46
- nextrec/loss/grad_norm.py +13 -13
- nextrec/models/multi_task/esmm.py +10 -11
- nextrec/models/multi_task/mmoe.py +20 -19
- nextrec/models/multi_task/ple.py +35 -34
- nextrec/models/multi_task/poso.py +23 -21
- nextrec/models/multi_task/share_bottom.py +18 -17
- nextrec/models/ranking/afm.py +4 -3
- nextrec/models/ranking/autoint.py +4 -3
- nextrec/models/ranking/dcn.py +4 -3
- nextrec/models/ranking/dcn_v2.py +4 -3
- nextrec/models/ranking/deepfm.py +4 -3
- nextrec/models/ranking/dien.py +2 -2
- nextrec/models/ranking/din.py +2 -2
- nextrec/models/ranking/eulernet.py +4 -3
- nextrec/models/ranking/ffm.py +4 -3
- nextrec/models/ranking/fibinet.py +2 -2
- nextrec/models/ranking/fm.py +4 -3
- nextrec/models/ranking/lr.py +4 -3
- nextrec/models/ranking/masknet.py +4 -5
- nextrec/models/ranking/pnn.py +5 -4
- nextrec/models/ranking/widedeep.py +8 -8
- nextrec/models/ranking/xdeepfm.py +5 -4
- nextrec/utils/console.py +20 -6
- nextrec/utils/data.py +154 -32
- nextrec/utils/model.py +86 -1
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/METADATA +5 -6
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/RECORD +37 -36
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/WHEEL +0 -0
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 24/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -38,6 +38,7 @@ from nextrec.basic.features import (
     SequenceFeature,
     SparseFeature,
 )
+from nextrec.basic.heads import RetrievalHead
 from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
 from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
 from nextrec.basic.session import create_session, resolve_save_path

@@ -48,6 +49,7 @@ from nextrec.data.dataloader import (
     TensorDictDataset,
     build_tensors_from_data,
 )
+from nextrec.utils.data import check_streaming_support
 from nextrec.loss import (
     BPRLoss,
     GradNormLossWeighting,

@@ -68,6 +70,7 @@ from nextrec.utils.torch_utils import (
     init_process_group,
     to_tensor,
 )
+from nextrec.utils.model import compute_ranking_loss


 class BaseModel(FeatureSet, nn.Module):

@@ -87,13 +90,18 @@ class BaseModel(FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
+        training_mode: (
+            Literal["pointwise", "pairwise", "listwise"]
+            | list[Literal["pointwise", "pairwise", "listwise"]]
+        ) = "pointwise",
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         device: str = "cpu",
         early_stop_patience: int = 20,
-
+        early_stop_monitor_task: str | None = None,
+        metrics_sample_limit: int | None = 200000,
         session_id: str | None = None,
         callbacks: list[Callback] | None = None,
         distributed: bool = False,

@@ -112,6 +120,7 @@ class BaseModel(FeatureSet, nn.Module):
             target: Target column name. e.g., 'label' or ['label1', 'label2'].
             id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
             task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
+            training_mode: Training mode for ranking tasks; a single mode or a list per task.

             embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
             dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.

@@ -120,7 +129,8 @@ class BaseModel(FeatureSet, nn.Module):

             device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
             early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-
+            early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
+            metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
             session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
             callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].

@@ -149,9 +159,11 @@ class BaseModel(FeatureSet, nn.Module):
         self.session = create_session(session_id)
         self.session_path = self.session.root  # pwd/session_id, path for this session
         self.checkpoint_path = os.path.join(
-            self.session_path, self.model_name + "_checkpoint.pt"
+            self.session_path, self.model_name.upper() + "_checkpoint.pt"
         )  # e.g., pwd/session_id/DeepFM_checkpoint.pt
-        self.best_path = os.path.join(
+        self.best_path = os.path.join(
+            self.session_path, self.model_name.upper() + "_best.pt"
+        )
         self.features_config_path = os.path.join(
             self.session_path, "features_config.pkl"
         )

@@ -161,6 +173,22 @@ class BaseModel(FeatureSet, nn.Module):

         self.task = self.default_task if task is None else task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
+        if isinstance(training_mode, list):
+            if len(training_mode) != self.nums_task:
+                raise ValueError(
+                    "[BaseModel-init Error] training_mode list length must match number of tasks."
+                )
+            self.training_modes = list(training_mode)
+        else:
+            self.training_modes = [training_mode] * self.nums_task
+        for mode in self.training_modes:
+            if mode not in {"pointwise", "pairwise", "listwise"}:
+                raise ValueError(
+                    "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+                )
+        self.training_mode = (
+            self.training_modes if self.nums_task > 1 else self.training_modes[0]
+        )

         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
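Note: the new training_mode argument accepts one mode broadcast to every task or an explicit per-task list, and anything outside {'pointwise', 'pairwise', 'listwise'} is rejected at init. A self-contained sketch of the normalization rule added above (illustration only, not package code):

    def normalize_training_modes(training_mode, nums_task):
        # Broadcast a single mode to all tasks, or require one mode per task.
        if isinstance(training_mode, list):
            if len(training_mode) != nums_task:
                raise ValueError("training_mode list length must match number of tasks.")
            modes = list(training_mode)
        else:
            modes = [training_mode] * nums_task
        for mode in modes:
            if mode not in {"pointwise", "pairwise", "listwise"}:
                raise ValueError("training_mode must be one of {'pointwise', 'pairwise', 'listwise'}.")
        return modes

    print(normalize_training_modes("pairwise", 3))                 # ['pairwise', 'pairwise', 'pairwise']
    print(normalize_training_modes(["pointwise", "listwise"], 2))  # ['pointwise', 'listwise']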
@@ -171,9 +199,10 @@ class BaseModel(FeatureSet, nn.Module):
         self.loss_weight = None

         self.early_stop_patience = early_stop_patience
+        self.early_stop_monitor_task = early_stop_monitor_task
         # max samples to keep for training metrics, in case of large training set
-        self.
-            None if
+        self.metrics_sample_limit = (
+            None if metrics_sample_limit is None else int(metrics_sample_limit)
         )
         self.max_gradient_norm = 1.0
         self.logger_initialized = False

@@ -397,6 +426,33 @@ class BaseModel(FeatureSet, nn.Module):
             Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
             callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
+        default_losses = {
+            "pointwise": "bce",
+            "pairwise": "bpr",
+            "listwise": "listnet",
+        }
+        effective_loss = loss
+        if effective_loss is None:
+            loss_list = [default_losses[mode] for mode in self.training_modes]
+        elif isinstance(effective_loss, list):
+            if not effective_loss:
+                loss_list = [default_losses[mode] for mode in self.training_modes]
+            else:
+                if len(effective_loss) != self.nums_task:
+                    raise ValueError(
+                        f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
+                    )
+                loss_list = list(effective_loss)
+        else:
+            loss_list = [effective_loss] * self.nums_task
+
+        for idx, mode in enumerate(self.training_modes):
+            if isinstance(loss_list[idx], str) and loss_list[idx] in {
+                "bce",
+                "binary_crossentropy",
+            }:
+                if mode in {"pairwise", "listwise"}:
+                    loss_list[idx] = default_losses[mode]
         if loss_params is None:
             self.loss_params = {}
         else:
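Note: compile() now derives a default loss per task from its training mode and upgrades a BCE loss to the mode-appropriate default for pairwise/listwise tasks. A self-contained sketch of that resolution logic, mirroring the hunk above (illustration only):

    DEFAULT_LOSSES = {"pointwise": "bce", "pairwise": "bpr", "listwise": "listnet"}

    def resolve_losses(loss, training_modes):
        nums_task = len(training_modes)
        if loss is None or (isinstance(loss, list) and not loss):
            loss_list = [DEFAULT_LOSSES[m] for m in training_modes]
        elif isinstance(loss, list):
            if len(loss) != nums_task:
                raise ValueError("Number of loss functions must match number of tasks.")
            loss_list = list(loss)
        else:
            loss_list = [loss] * nums_task
        # BCE only makes sense pointwise; swap in the mode default otherwise.
        for i, mode in enumerate(training_modes):
            if loss_list[i] in {"bce", "binary_crossentropy"} and mode != "pointwise":
                loss_list[i] = DEFAULT_LOSSES[mode]
        return loss_list

    print(resolve_losses(None, ["pointwise", "pairwise"]))  # ['bce', 'bpr']
    print(resolve_losses("bce", ["listwise"]))              # ['listnet']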
@@ -426,16 +482,8 @@ class BaseModel(FeatureSet, nn.Module):
             else None
         )

-        self.loss_config =
+        self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
         self.loss_params = loss_params or {}
-        if isinstance(loss, list):
-            if len(loss) != self.nums_task:
-                raise ValueError(
-                    f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
-                )
-            loss_list = list(loss)
-        else:
-            loss_list = [loss] * self.nums_task
         if isinstance(self.loss_params, dict):
             loss_params_list = [self.loss_params] * self.nums_task
         else:

@@ -456,7 +504,7 @@ class BaseModel(FeatureSet, nn.Module):
                 "[BaseModel-compile Error] GradNorm requires multi-task setup."
             )
             self.grad_norm = GradNormLossWeighting(
-
+                nums_task=self.nums_task, device=self.device
             )
             self.loss_weights = None
         elif (

@@ -469,7 +517,7 @@ class BaseModel(FeatureSet, nn.Module):
             grad_norm_params = dict(loss_weights)
             grad_norm_params.pop("method", None)
             self.grad_norm = GradNormLossWeighting(
-
+                nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
             self.loss_weights = None
         elif loss_weights is None:

@@ -507,6 +555,7 @@ class BaseModel(FeatureSet, nn.Module):
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
             )
+        # single-task
         if self.nums_task == 1:
             if y_pred.dim() == 1:
                 y_pred = y_pred.view(-1, 1)

@@ -514,16 +563,30 @@ class BaseModel(FeatureSet, nn.Module):
                 y_true = y_true.view(-1, 1)
             if y_pred.shape != y_true.shape:
                 raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+            if loss_fn is None:
+                raise ValueError(
+                    "[BaseModel-compute_loss Error] Loss function is not configured. Call compile() first."
+                )
+            mode = self.training_modes[0]
             task_dim = (
                 self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
             )
-            if
-                loss =
+            if mode in {"pairwise", "listwise"}:
+                loss = compute_ranking_loss(
+                    training_mode=mode,
+                    loss_fn=loss_fn,
+                    y_pred=y_pred,
+                    y_true=y_true,
+                )
+            elif task_dim == 1:
+                loss = loss_fn(y_pred.view(-1), y_true.view(-1))
             else:
-                loss =
+                loss = loss_fn(y_pred, y_true)
             if self.loss_weights is not None:
                 loss *= self.loss_weights[0]
             return loss
+
         # multi-task
         if y_pred.shape != y_true.shape:
             raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
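Note: compute_ranking_loss itself lives in nextrec/utils/model.py (+86 lines in this release) and its body is not shown here. As a rough illustration of what a pairwise path typically does, the helper below forms positive/negative score margins within a batch and applies a BPR-style criterion; it is a guess at the shape of the contract, not the package's actual implementation:

    import torch
    import torch.nn.functional as F

    def pairwise_bpr_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Illustrative only: BPR over all (positive, negative) pairs in the batch.
        mask = y_true.view(-1) > 0.5
        pos = y_pred.view(-1)[mask]
        neg = y_pred.view(-1)[~mask]
        if pos.numel() == 0 or neg.numel() == 0:
            return y_pred.sum() * 0.0  # no valid pairs in this batch
        diff = pos.unsqueeze(1) - neg.unsqueeze(0)  # [P, N] score margins
        return -F.logsigmoid(diff).mean()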
@@ -536,7 +599,16 @@ class BaseModel(FeatureSet, nn.Module):
         for i, (start, end) in enumerate(slices):  # type: ignore
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
-
+            mode = self.training_modes[i]
+            if mode in {"pairwise", "listwise"}:
+                task_loss = compute_ranking_loss(
+                    training_mode=mode,
+                    loss_fn=self.loss_fn[i],
+                    y_pred=y_pred_i,
+                    y_true=y_true_i,
+                )
+            else:
+                task_loss = self.loss_fn[i](y_pred_i, y_true_i)
             task_losses.append(task_loss)
             if self.grad_norm is not None:
                 if self.grad_norm_shared_params is None:

@@ -602,8 +674,8 @@ class BaseModel(FeatureSet, nn.Module):
         user_id_column: str | None = None,
         validation_split: float | None = None,
         num_workers: int = 0,
-
-
+        use_tensorboard: bool = True,
+        auto_ddp_sampler: bool = True,
         log_interval: int = 1,
     ):
         """

@@ -619,8 +691,8 @@ class BaseModel(FeatureSet, nn.Module):
             user_id_column: Column name for GAUC-style metrics;.
             validation_split: Ratio to split training data when valid_data is None.
             num_workers: DataLoader worker count.
-
-
+            use_tensorboard: Enable tensorboard logging.
+            auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
             log_interval: Log validation metrics every N epochs (still computes metrics each epoch).

         Notes:
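Note: both new flags default to True, so existing fit() calls keep their behavior. A hedged usage sketch (`model` stands for any compiled nextrec model and `train_df`/`valid_df` for prepared DataFrames; only the keyword arguments shown in this hunk are taken from the diff):

    model.fit(
        train_df,
        valid_data=valid_df,
        use_tensorboard=False,   # skip tensorboard event files for this run
        auto_ddp_sampler=False,  # each rank already reads its own shard
        log_interval=1,
    )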
@@ -662,7 +734,7 @@ class BaseModel(FeatureSet, nn.Module):
         setup_logger(session_id=self.session_id)
         self.logger_initialized = True
         self.training_logger = (
-            TrainingLogger(session=self.session,
+            TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
             if self.is_main_process
             else None
         )

@@ -680,18 +752,21 @@ class BaseModel(FeatureSet, nn.Module):
         if self.nums_task == 1:
             monitor_metric = f"val_{self.metrics[0]}"
         else:
-
+            # Determine which task to monitor for early stopping
+            monitor_task = self.early_stop_monitor_task
+            if monitor_task is None:
+                monitor_task = self.target_columns[0]
+            elif monitor_task not in self.target_columns:
+                raise ValueError(
+                    f"[BaseModel-fit Error] early_stop_monitor_task '{monitor_task}' not found in target_columns {self.target_columns}."
+                )
+            monitor_metric = f"val_{self.metrics[0]}_{monitor_task}"

         existing_callbacks = self.callbacks.callbacks
-        has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
-        has_checkpoint = any(
-            isinstance(cb, CheckpointSaver) for cb in existing_callbacks
-        )
-        has_lr_scheduler = any(
-            isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
-        )

-        if self.early_stop_patience > 0 and not
+        if self.early_stop_patience > 0 and not any(
+            isinstance(cb, EarlyStopper) for cb in existing_callbacks
+        ):
             self.callbacks.append(
                 EarlyStopper(
                     monitor=monitor_metric,
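Note: for multi-task models the monitored key is now derived from the monitored task, so AUC on a 'click' target is tracked as val_auc_click. A small standalone sketch of that resolution (mirrors the hunk above; illustration only):

    def resolve_monitor_metric(metrics, target_columns, early_stop_monitor_task=None):
        monitor_task = early_stop_monitor_task or target_columns[0]
        if monitor_task not in target_columns:
            raise ValueError(f"{monitor_task!r} not found in {target_columns}")
        return f"val_{metrics[0]}_{monitor_task}"

    print(resolve_monitor_metric(["auc"], ["click", "purchase"]))              # val_auc_click
    print(resolve_monitor_metric(["auc"], ["click", "purchase"], "purchase"))  # val_auc_purchase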
@@ -702,7 +777,9 @@ class BaseModel(FeatureSet, nn.Module):
             )
         )

-        if self.is_main_process and not
+        if self.is_main_process and not any(
+            isinstance(cb, CheckpointSaver) for cb in existing_callbacks
+        ):
             self.callbacks.append(
                 CheckpointSaver(
                     best_path=self.best_path,

@@ -714,7 +791,9 @@ class BaseModel(FeatureSet, nn.Module):
             )
         )

-        if self.scheduler_fn is not None and not
+        if self.scheduler_fn is not None and not any(
+            isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
+        ):
             self.callbacks.append(
                 LearningRateScheduler(
                     scheduler=self.scheduler_fn,

@@ -737,16 +816,16 @@ class BaseModel(FeatureSet, nn.Module):
         self.stop_training = False
         self.best_checkpoint_path = self.best_path
         use_ddp_sampler = (
-
+            auto_ddp_sampler
             and self.distributed
             and dist.is_available()
             and dist.is_initialized()
         )

-        if not
+        if not auto_ddp_sampler and self.distributed and self.is_main_process:
             logging.info(
                 colorize(
-                    "[Distributed Info]
+                    "[Distributed Info] auto_ddp_sampler=False; assuming data is already sharded per rank.",
                     color="yellow",
                 )
             )

@@ -825,12 +904,12 @@ class BaseModel(FeatureSet, nn.Module):
         # If split-based loader was built without sampler, attach here when enabled
         if (
             self.distributed
-            and
+            and auto_ddp_sampler
             and isinstance(train_loader, DataLoader)
             and train_sampler is None
         ):
             raise NotImplementedError(
-                "[BaseModel-fit Error]
+                "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
             )
             # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)

@@ -840,7 +919,7 @@ class BaseModel(FeatureSet, nn.Module):
             needs_user_ids=self.needs_user_ids,
             user_id_column=user_id_column,
             num_workers=num_workers,
-
+            auto_ddp_sampler=auto_ddp_sampler,
         )
         try:
             self.steps_per_epoch = len(train_loader)

@@ -862,7 +941,7 @@ class BaseModel(FeatureSet, nn.Module):
         logging.info("")
         tb_dir = (
             self.training_logger.tensorboard_logdir
-            if self.training_logger and self.training_logger.
+            if self.training_logger and self.training_logger.use_tensorboard
             else None
         )
         if tb_dir:

@@ -1054,7 +1133,7 @@ class BaseModel(FeatureSet, nn.Module):
         y_true_list = []
         y_pred_list = []
         collect_metrics = getattr(self, "collect_train_metrics", True)
-        max_samples = getattr(self, "
+        max_samples = getattr(self, "metrics_sample_limit", None)
         collected_samples = 0
         metrics_capped = False

@@ -1183,14 +1262,14 @@ class BaseModel(FeatureSet, nn.Module):
         needs_user_ids: bool,
         user_id_column: str | None = "user_id",
         num_workers: int = 0,
-
+        auto_ddp_sampler: bool = True,
     ) -> tuple[DataLoader | None, np.ndarray | None]:
         if valid_data is None:
             return None, None
         if isinstance(valid_data, DataLoader):
-            if
+            if auto_ddp_sampler and self.distributed:
                 raise NotImplementedError(
-                    "[BaseModel-prepare_validation_data Error]
+                    "[BaseModel-prepare_validation_data Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
                 )
             # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
         else:

@@ -1199,7 +1278,7 @@ class BaseModel(FeatureSet, nn.Module):
             valid_sampler = None
             valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True)  # type: ignore
             if (
-
+                auto_ddp_sampler
                 and self.distributed
                 and valid_dataset is not None
                 and dist.is_available()

@@ -1372,11 +1451,11 @@ class BaseModel(FeatureSet, nn.Module):
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int = 32,
         save_path: str | os.PathLike | None = None,
-        save_format:
+        save_format: str = "csv",
         include_ids: bool | None = None,
         id_columns: str | list[str] | None = None,
         return_dataframe: bool = True,
-
+        stream_chunk_size: int = 10000,
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
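Note: predict() now takes save_format and stream_chunk_size directly. A hedged usage sketch (`model` is a trained nextrec model, `test_df` a DataFrame with the expected feature columns; the keyword names come from the signature above):

    preds = model.predict(
        test_df,
        batch_size=1024,
        save_path="preds/run1",   # resolved under the session's predictions dir
        save_format="parquet",    # csv and parquet support true streaming writes
        stream_chunk_size=10000,  # rows per chunk in streaming mode
        return_dataframe=True,
    )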
@@ -1391,7 +1470,7 @@ class BaseModel(FeatureSet, nn.Module):
             include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
             id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
-
+            stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
             num_workers: DataLoader worker count.
         """
         self.eval()

@@ -1412,7 +1491,7 @@ class BaseModel(FeatureSet, nn.Module):
                 save_path=save_path,
                 save_format=save_format,
                 include_ids=include_ids,
-
+                stream_chunk_size=stream_chunk_size,
                 return_dataframe=return_dataframe,
                 id_columns=predict_id_columns,
             )

@@ -1438,7 +1517,7 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_size=batch_size,
                 shuffle=False,
                 streaming=True,
-                chunk_size=
+                chunk_size=stream_chunk_size,
             )
         else:
             data_loader = self.prepare_data_loader(

@@ -1516,11 +1595,18 @@ class BaseModel(FeatureSet, nn.Module):
             else y_pred_all
         )
         if save_path is not None:
-
-
-
+            # Check streaming write support
+            if not check_streaming_support(save_format):
+                logging.warning(
+                    f"[BaseModel-predict Warning] Format '{save_format}' does not support streaming writes. "
+                    "The entire result will be saved at once. Use csv or parquet for large datasets."
                 )
-
+
+            # Get file extension from format
+            from nextrec.utils.data import FILE_FORMAT_CONFIG
+
+            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+
             target_path = resolve_save_path(
                 path=save_path,
                 default_dir=self.session.predictions_dir,
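Note: check_streaming_support and FILE_FORMAT_CONFIG come from nextrec/utils/data.py, whose new body is not shown in this file's diff. Judging only from the call sites, the config maps each format to at least an extension list; a shape like the following would satisfy both uses (assumed for illustration, not the package's actual table):

    FILE_FORMAT_CONFIG = {
        "csv":     {"extension": [".csv"],     "streaming": True},
        "parquet": {"extension": [".parquet"], "streaming": True},
        "feather": {"extension": [".feather"], "streaming": False},
        "excel":   {"extension": [".xlsx"],    "streaming": False},
        "hdf5":    {"extension": [".h5"],      "streaming": False},
    }

    def check_streaming_support(save_format: str) -> bool:
        return FILE_FORMAT_CONFIG.get(save_format, {}).get("streaming", False)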
@@ -1539,10 +1625,21 @@ class BaseModel(FeatureSet, nn.Module):
                     f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
                 )
                 df_to_save = pd.concat([id_df, df_to_save], axis=1)
+
+            # Save based on format
             if save_format == "csv":
                 df_to_save.to_csv(target_path, index=False)
-
+            elif save_format == "parquet":
                 df_to_save.to_parquet(target_path, index=False)
+            elif save_format == "feather":
+                df_to_save.to_feather(target_path)
+            elif save_format == "excel":
+                df_to_save.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                df_to_save.to_hdf(target_path, key="predictions", mode="w")
+            else:
+                raise ValueError(f"Unsupported save format: {save_format}")
+
             logging.info(
                 colorize(f"Predictions saved to: {target_path}", color="green")
             )

@@ -1553,9 +1650,9 @@ class BaseModel(FeatureSet, nn.Module):
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int,
         save_path: str | os.PathLike,
-        save_format:
+        save_format: str,
         include_ids: bool,
-
+        stream_chunk_size: int,
         return_dataframe: bool,
         id_columns: list[str] | None = None,
     ) -> pd.DataFrame | Path:

@@ -1572,7 +1669,7 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_size=batch_size,
                 shuffle=False,
                 streaming=True,
-                chunk_size=
+                chunk_size=stream_chunk_size,
             )
         elif not isinstance(data, DataLoader):
             data_loader = self.prepare_data_loader(

@@ -1594,7 +1691,17 @@ class BaseModel(FeatureSet, nn.Module):
                 "When using streaming mode, set num_workers=0 to avoid reading data multiple times."
             )

-
+        # Check streaming support and prepare file path
+        if not check_streaming_support(save_format):
+            logging.warning(
+                f"[Predict Streaming Warning] Format '{save_format}' does not support streaming writes. "
+                "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+            )
+
+        from nextrec.utils.data import FILE_FORMAT_CONFIG
+
+        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+
         target_path = resolve_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,

@@ -1605,9 +1712,10 @@ class BaseModel(FeatureSet, nn.Module):
         target_path.parent.mkdir(parents=True, exist_ok=True)
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None
-
         pred_columns = None
-        collected_frames =
+        collected_frames = (
+            []
+        )  # used when return_dataframe=True or for non-streaming formats

         with torch.no_grad():
             for batch_data in progress(data_loader, description="Predicting"):

@@ -1649,27 +1757,48 @@ class BaseModel(FeatureSet, nn.Module):
                     )
                     df_batch = pd.concat([id_df, df_batch], axis=1)

+                # Streaming save based on format
                 if save_format == "csv":
                     df_batch.to_csv(
                         target_path, mode="a", header=not header_written, index=False
                     )
                     header_written = True
-
+                elif save_format == "parquet":
                     try:
                         import pyarrow as pa
                         import pyarrow.parquet as pq
                     except ImportError as exc:  # pragma: no cover
                         raise ImportError(
-                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow
+                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow."
                         ) from exc
                     table = pa.Table.from_pandas(df_batch, preserve_index=False)
                     if parquet_writer is None:
                         parquet_writer = pq.ParquetWriter(target_path, table.schema)
                     parquet_writer.write_table(table)
-
+                else:
+                    # Non-streaming formats: collect all data
                     collected_frames.append(df_batch)
+
+                if return_dataframe:
+                    if (
+                        save_format in ["csv", "parquet"]
+                        and df_batch not in collected_frames
+                    ):
+                        collected_frames.append(df_batch)
+
+        # Close writers
         if parquet_writer is not None:
             parquet_writer.close()
+        # For non-streaming formats, save collected data
+        if save_format in ["feather", "excel", "hdf5"] and collected_frames:
+            combined_df = pd.concat(collected_frames, ignore_index=True)
+            if save_format == "feather":
+                combined_df.to_feather(target_path)
+            elif save_format == "excel":
+                combined_df.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                combined_df.to_hdf(target_path, key="predictions", mode="w")
+
         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
             return (
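Note: the per-batch writer keeps memory flat for csv and parquet: csv appends with header tracking, parquet reuses one ParquetWriter across batches. The core of that pattern in isolation (requires pyarrow; the first batch fixes the schema, and the stand-in batches here replace real model output):

    import pandas as pd
    import pyarrow as pa
    import pyarrow.parquet as pq

    writer = None
    for df_batch in (pd.DataFrame({"score": [0.1, 0.9]}) for _ in range(3)):
        table = pa.Table.from_pandas(df_batch, preserve_index=False)
        if writer is None:
            writer = pq.ParquetWriter("preds.parquet", table.schema)  # opened once
        writer.write_table(table)
    if writer is not None:
        writer.close()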
@@ -1690,7 +1819,7 @@ class BaseModel(FeatureSet, nn.Module):
         target_path = resolve_save_path(
             path=save_path,
             default_dir=self.session_path,
-            default_name=self.model_name,
+            default_name=self.model_name.upper(),
             suffix=".pt",
             add_timestamp=add_timestamp,
         )

@@ -1844,7 +1973,9 @@ class BaseModel(FeatureSet, nn.Module):
         logger.info("")
         logger.info(
             colorize(
-                f"Model Summary: {self.model_name}",
+                f"Model Summary: {self.model_name.upper()}",
+                color="bright_blue",
+                bold=True,
             )
         )
         logger.info("")

@@ -1975,7 +2106,7 @@ class BaseModel(FeatureSet, nn.Module):
         logger.info("Other Settings:")
         logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
         logger.info(f"  Max Gradient Norm: {self.max_gradient_norm}")
-        logger.info(f"  Max Metrics Samples: {self.
+        logger.info(f"  Max Metrics Samples: {self.metrics_sample_limit}")
        logger.info(f"  Session ID: {self.session_id}")
         logger.info(f"  Features Config Path: {self.features_config_path}")
         logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")

@@ -2115,6 +2246,12 @@ class BaseMatchModel(BaseModel):
         )
         self.user_feature_names = {feature.name for feature in self.user_features_all}
         self.item_feature_names = {feature.name for feature in self.item_features_all}
+        self.head = RetrievalHead(
+            similarity_metric=self.similarity_metric,
+            temperature=self.temperature,
+            training_mode=self.training_mode,
+            apply_sigmoid=True,
+        )

     def compile(
         self,
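Note: RetrievalHead is new in nextrec/basic/heads.py (+99 lines), which this file's diff does not display. From the constructor arguments here and the forward refactor below, it packages the similarity-then-sigmoid logic that BaseMatchModel.forward previously inlined. A guessed minimal stand-in, for orientation only (the temperature scaling is an assumption):

    import torch
    import torch.nn as nn

    class RetrievalHeadSketch(nn.Module):
        # Illustrative stand-in for nextrec.basic.heads.RetrievalHead.
        def __init__(self, similarity_metric="cosine", temperature=1.0,
                     training_mode="pointwise", apply_sigmoid=True):
            super().__init__()
            self.temperature = temperature
            self.training_mode = training_mode
            self.apply_sigmoid = apply_sigmoid

        def forward(self, user_emb, item_emb, similarity_fn):
            similarity = similarity_fn(user_emb, item_emb) / self.temperature  # [B]
            if self.training_mode == "pointwise" and self.apply_sigmoid:
                return torch.sigmoid(similarity)
            return similarity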
@@ -2139,7 +2276,7 @@ class BaseMatchModel(BaseModel):
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
-                f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )

         default_loss_by_mode: dict[str, str] = {

@@ -2244,15 +2381,7 @@ class BaseMatchModel(BaseModel):
         user_emb = self.user_tower(user_input)  # [B, D]
         item_emb = self.item_tower(item_input)  # [B, D]

-
-            return user_emb, item_emb
-
-        similarity = self.compute_similarity(user_emb, item_emb)  # [B]
-
-        if self.training_mode == "pointwise":
-            return torch.sigmoid(similarity)
-        else:
-            return similarity
+        return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)

     def compute_loss(self, y_pred, y_true):
         if self.training_mode == "pointwise":

@@ -2308,7 +2437,7 @@ class BaseMatchModel(BaseModel):
         features: list,
         batch_size: int,
         num_workers: int = 0,
-
+        stream_chunk_size: int = 10000,
     ) -> DataLoader:
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):

@@ -2329,7 +2458,7 @@ class BaseMatchModel(BaseModel):
             batch_size=batch_size,
             shuffle=False,
             streaming=True,
-            chunk_size=
+            chunk_size=stream_chunk_size,
             num_workers=num_workers,
         )
         tensors = build_tensors_from_data(

@@ -2382,7 +2511,7 @@ class BaseMatchModel(BaseModel):
         ),
         batch_size: int = 512,
         num_workers: int = 0,
-
+        stream_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
         data_loader = self.prepare_feature_data(

@@ -2390,7 +2519,7 @@ class BaseMatchModel(BaseModel):
             self.user_features_all,
             batch_size,
             num_workers=num_workers,
-
+            stream_chunk_size=stream_chunk_size,
         )

         embeddings_list = []

@@ -2416,7 +2545,7 @@ class BaseMatchModel(BaseModel):
         ),
         batch_size: int = 512,
         num_workers: int = 0,
-
+        stream_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
         data_loader = self.prepare_feature_data(

@@ -2424,7 +2553,7 @@ class BaseMatchModel(BaseModel):
             self.item_features_all,
             batch_size,
             num_workers=num_workers,
-
+            stream_chunk_size=stream_chunk_size,
         )

         embeddings_list = []
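Note: every streaming entry point above now threads stream_chunk_size down to the loader's chunk_size, so user/item embedding export reads large files in bounded chunks. The mechanism is the usual chunked iteration, shown here with pandas as a stand-in for nextrec's streaming loader ("items.csv" is a placeholder file):

    import pandas as pd

    total = 0
    for chunk in pd.read_csv("items.csv", chunksize=10000):  # stream_chunk_size rows at a time
        total += len(chunk)  # each chunk is processed and discarded, keeping memory flat
    print(total)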