nextrec-0.4.17-py3-none-any.whl → nextrec-0.4.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/heads.py +1 -3
- nextrec/basic/loggers.py +5 -5
- nextrec/basic/model.py +210 -82
- nextrec/cli.py +5 -5
- nextrec/data/dataloader.py +93 -95
- nextrec/data/preprocessor.py +108 -46
- nextrec/loss/grad_norm.py +13 -13
- nextrec/models/multi_task/esmm.py +9 -11
- nextrec/models/multi_task/mmoe.py +18 -18
- nextrec/models/multi_task/ple.py +33 -33
- nextrec/models/multi_task/poso.py +21 -20
- nextrec/models/multi_task/share_bottom.py +16 -16
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +2 -2
- nextrec/models/ranking/eulernet.py +2 -2
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/ranking/fm.py +2 -2
- nextrec/models/ranking/lr.py +2 -2
- nextrec/models/ranking/masknet.py +2 -4
- nextrec/models/ranking/pnn.py +3 -3
- nextrec/models/ranking/widedeep.py +6 -7
- nextrec/models/ranking/xdeepfm.py +3 -3
- nextrec/utils/console.py +1 -1
- nextrec/utils/data.py +154 -32
- nextrec/utils/model.py +86 -1
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/METADATA +8 -7
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/RECORD +34 -34
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/WHEEL +0 -0
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 24/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -49,6 +49,7 @@ from nextrec.data.dataloader import (
     TensorDictDataset,
     build_tensors_from_data,
 )
+from nextrec.utils.data import check_streaming_support
 from nextrec.loss import (
     BPRLoss,
     GradNormLossWeighting,
@@ -69,6 +70,7 @@ from nextrec.utils.torch_utils import (
     init_process_group,
     to_tensor,
 )
+from nextrec.utils.model import compute_ranking_loss
 
 
 class BaseModel(FeatureSet, nn.Module):
@@ -88,13 +90,18 @@ class BaseModel(FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
+        training_mode: (
+            Literal["pointwise", "pairwise", "listwise"]
+            | list[Literal["pointwise", "pairwise", "listwise"]]
+        ) = "pointwise",
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         device: str = "cpu",
         early_stop_patience: int = 20,
-
+        early_stop_monitor_task: str | None = None,
+        metrics_sample_limit: int | None = 200000,
         session_id: str | None = None,
         callbacks: list[Callback] | None = None,
         distributed: bool = False,
@@ -113,6 +120,7 @@ class BaseModel(FeatureSet, nn.Module):
            target: Target column name. e.g., 'label' or ['label1', 'label2'].
            id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
            task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
+           training_mode: Training mode for ranking tasks; a single mode or a list per task.
 
            embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
            dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
@@ -121,7 +129,8 @@ class BaseModel(FeatureSet, nn.Module):
 
            device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
            early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-
+           early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
+           metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
            session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
            callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
 
@@ -150,9 +159,11 @@ class BaseModel(FeatureSet, nn.Module):
         self.session = create_session(session_id)
         self.session_path = self.session.root  # pwd/session_id, path for this session
         self.checkpoint_path = os.path.join(
-            self.session_path, self.model_name + "_checkpoint.pt"
+            self.session_path, self.model_name.upper() + "_checkpoint.pt"
         )  # e.g., pwd/session_id/DeepFM_checkpoint.pt
-        self.best_path = os.path.join(
+        self.best_path = os.path.join(
+            self.session_path, self.model_name.upper() + "_best.pt"
+        )
         self.features_config_path = os.path.join(
             self.session_path, "features_config.pkl"
         )
@@ -162,6 +173,22 @@ class BaseModel(FeatureSet, nn.Module):
 
         self.task = self.default_task if task is None else task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
+        if isinstance(training_mode, list):
+            if len(training_mode) != self.nums_task:
+                raise ValueError(
+                    "[BaseModel-init Error] training_mode list length must match number of tasks."
+                )
+            self.training_modes = list(training_mode)
+        else:
+            self.training_modes = [training_mode] * self.nums_task
+        for mode in self.training_modes:
+            if mode not in {"pointwise", "pairwise", "listwise"}:
+                raise ValueError(
+                    "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+                )
+        self.training_mode = (
+            self.training_modes if self.nums_task > 1 else self.training_modes[0]
+        )
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
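Aside (not part of the diff): the new training_mode argument accepts either a single mode or one mode per task; pointwise keeps the usual per-sample loss, while pairwise and listwise switch the default loss to BPR and ListNet respectively (see the compile() hunk further down). The snippet below simply reruns the normalization logic added above, standalone, for a hypothetical two-task model.

training_mode = ["pointwise", "pairwise"]
nums_task = 2

if isinstance(training_mode, list):
    if len(training_mode) != nums_task:
        raise ValueError("training_mode list length must match number of tasks.")
    training_modes = list(training_mode)
else:
    training_modes = [training_mode] * nums_task
for mode in training_modes:
    if mode not in {"pointwise", "pairwise", "listwise"}:
        raise ValueError("training_mode must be pointwise, pairwise, or listwise.")
print(training_modes)  # ['pointwise', 'pairwise']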
@@ -172,9 +199,10 @@ class BaseModel(FeatureSet, nn.Module):
         self.loss_weight = None
 
         self.early_stop_patience = early_stop_patience
+        self.early_stop_monitor_task = early_stop_monitor_task
         # max samples to keep for training metrics, in case of large training set
-        self.
-            None if
+        self.metrics_sample_limit = (
+            None if metrics_sample_limit is None else int(metrics_sample_limit)
         )
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
@@ -398,6 +426,33 @@ class BaseModel(FeatureSet, nn.Module):
            Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
            callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
+        default_losses = {
+            "pointwise": "bce",
+            "pairwise": "bpr",
+            "listwise": "listnet",
+        }
+        effective_loss = loss
+        if effective_loss is None:
+            loss_list = [default_losses[mode] for mode in self.training_modes]
+        elif isinstance(effective_loss, list):
+            if not effective_loss:
+                loss_list = [default_losses[mode] for mode in self.training_modes]
+            else:
+                if len(effective_loss) != self.nums_task:
+                    raise ValueError(
+                        f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
+                    )
+                loss_list = list(effective_loss)
+        else:
+            loss_list = [effective_loss] * self.nums_task
+
+        for idx, mode in enumerate(self.training_modes):
+            if isinstance(loss_list[idx], str) and loss_list[idx] in {
+                "bce",
+                "binary_crossentropy",
+            }:
+                if mode in {"pairwise", "listwise"}:
+                    loss_list[idx] = default_losses[mode]
         if loss_params is None:
             self.loss_params = {}
         else:
@@ -427,16 +482,8 @@ class BaseModel(FeatureSet, nn.Module):
             else None
         )
 
-        self.loss_config =
+        self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
         self.loss_params = loss_params or {}
-        if isinstance(loss, list):
-            if len(loss) != self.nums_task:
-                raise ValueError(
-                    f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
-                )
-            loss_list = list(loss)
-        else:
-            loss_list = [loss] * self.nums_task
         if isinstance(self.loss_params, dict):
             loss_params_list = [self.loss_params] * self.nums_task
         else:
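Aside (not part of the diff): how the new default-loss resolution in compile() behaves for a single pairwise task. The helper below reruns the mapping standalone; it does not call the real compile() and only covers the single-loss case.

default_losses = {"pointwise": "bce", "pairwise": "bpr", "listwise": "listnet"}
training_modes = ["pairwise"]

def resolve(loss):
    # mirrors the 0.4.19 logic: None -> per-mode default; an explicit "bce"
    # is remapped for pairwise/listwise tasks
    loss_list = [default_losses[m] for m in training_modes] if loss is None else [loss]
    for idx, mode in enumerate(training_modes):
        if loss_list[idx] in {"bce", "binary_crossentropy"} and mode in {"pairwise", "listwise"}:
            loss_list[idx] = default_losses[mode]
    return loss_list

print(resolve(None))     # ['bpr']
print(resolve("bce"))    # ['bpr']  (bce is overridden for a pairwise task)
print(resolve("focal"))  # ['focal'] (other strings are left untouched)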
@@ -457,7 +504,7 @@ class BaseModel(FeatureSet, nn.Module):
                     "[BaseModel-compile Error] GradNorm requires multi-task setup."
                 )
             self.grad_norm = GradNormLossWeighting(
-
+                nums_task=self.nums_task, device=self.device
             )
             self.loss_weights = None
         elif (
@@ -470,7 +517,7 @@ class BaseModel(FeatureSet, nn.Module):
             grad_norm_params = dict(loss_weights)
             grad_norm_params.pop("method", None)
             self.grad_norm = GradNormLossWeighting(
-
+                nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
             self.loss_weights = None
         elif loss_weights is None:
@@ -508,6 +555,7 @@ class BaseModel(FeatureSet, nn.Module):
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
             )
+        # single-task
         if self.nums_task == 1:
             if y_pred.dim() == 1:
                 y_pred = y_pred.view(-1, 1)
@@ -515,16 +563,30 @@ class BaseModel(FeatureSet, nn.Module):
                 y_true = y_true.view(-1, 1)
             if y_pred.shape != y_true.shape:
                 raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
+            if loss_fn is None:
+                raise ValueError(
+                    "[BaseModel-compute_loss Error] Loss function is not configured. Call compile() first."
+                )
+            mode = self.training_modes[0]
             task_dim = (
                 self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
             )
-            if
-                loss =
+            if mode in {"pairwise", "listwise"}:
+                loss = compute_ranking_loss(
+                    training_mode=mode,
+                    loss_fn=loss_fn,
+                    y_pred=y_pred,
+                    y_true=y_true,
+                )
+            elif task_dim == 1:
+                loss = loss_fn(y_pred.view(-1), y_true.view(-1))
             else:
-                loss =
+                loss = loss_fn(y_pred, y_true)
             if self.loss_weights is not None:
                 loss *= self.loss_weights[0]
             return loss
+
         # multi-task
         if y_pred.shape != y_true.shape:
             raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
@@ -537,7 +599,16 @@ class BaseModel(FeatureSet, nn.Module):
         for i, (start, end) in enumerate(slices):  # type: ignore
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
-
+            mode = self.training_modes[i]
+            if mode in {"pairwise", "listwise"}:
+                task_loss = compute_ranking_loss(
+                    training_mode=mode,
+                    loss_fn=self.loss_fn[i],
+                    y_pred=y_pred_i,
+                    y_true=y_true_i,
+                )
+            else:
+                task_loss = self.loss_fn[i](y_pred_i, y_true_i)
             task_losses.append(task_loss)
         if self.grad_norm is not None:
             if self.grad_norm_shared_params is None:
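Aside (not part of the diff): compute_ranking_loss itself lives in nextrec/utils/model.py (+86 lines in this release) and its body is not shown here. To make the pointwise vs. pairwise distinction concrete, below is a generic, self-contained BPR-style pairwise reduction of the kind such a helper performs over in-batch positives and negatives; it is an illustration, not the package's implementation.

import torch
import torch.nn.functional as F

def bpr_pairwise_loss(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """BPR loss over all positive/negative score pairs in a batch (illustrative)."""
    scores, labels = scores.view(-1), labels.view(-1)
    pos = scores[labels > 0.5]            # scores of positive items
    neg = scores[labels <= 0.5]           # scores of negative items
    if pos.numel() == 0 or neg.numel() == 0:
        return scores.sum() * 0.0         # degenerate batch: no usable pairs
    diff = pos.unsqueeze(1) - neg.unsqueeze(0)   # [P, N] pairwise score differences
    return -F.logsigmoid(diff).mean()

scores = torch.tensor([2.0, 0.5, -1.0])
labels = torch.tensor([1.0, 0.0, 0.0])
print(bpr_pairwise_loss(scores, labels))  # small positive scalar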
@@ -603,8 +674,8 @@ class BaseModel(FeatureSet, nn.Module):
         user_id_column: str | None = None,
         validation_split: float | None = None,
         num_workers: int = 0,
-
-
+        use_tensorboard: bool = True,
+        auto_ddp_sampler: bool = True,
         log_interval: int = 1,
     ):
         """
@@ -620,8 +691,8 @@ class BaseModel(FeatureSet, nn.Module):
            user_id_column: Column name for GAUC-style metrics;.
            validation_split: Ratio to split training data when valid_data is None.
            num_workers: DataLoader worker count.
-
-
+           use_tensorboard: Enable tensorboard logging.
+           auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
            log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
 
         Notes:
@@ -663,7 +734,7 @@ class BaseModel(FeatureSet, nn.Module):
         setup_logger(session_id=self.session_id)
         self.logger_initialized = True
         self.training_logger = (
-            TrainingLogger(session=self.session,
+            TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
             if self.is_main_process
             else None
         )
@@ -681,18 +752,21 @@ class BaseModel(FeatureSet, nn.Module):
         if self.nums_task == 1:
             monitor_metric = f"val_{self.metrics[0]}"
         else:
-
+            # Determine which task to monitor for early stopping
+            monitor_task = self.early_stop_monitor_task
+            if monitor_task is None:
+                monitor_task = self.target_columns[0]
+            elif monitor_task not in self.target_columns:
+                raise ValueError(
+                    f"[BaseModel-fit Error] early_stop_monitor_task '{monitor_task}' not found in target_columns {self.target_columns}."
+                )
+            monitor_metric = f"val_{self.metrics[0]}_{monitor_task}"
 
         existing_callbacks = self.callbacks.callbacks
-        has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
-        has_checkpoint = any(
-            isinstance(cb, CheckpointSaver) for cb in existing_callbacks
-        )
-        has_lr_scheduler = any(
-            isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
-        )
 
-        if self.early_stop_patience > 0 and not
+        if self.early_stop_patience > 0 and not any(
+            isinstance(cb, EarlyStopper) for cb in existing_callbacks
+        ):
             self.callbacks.append(
                 EarlyStopper(
                     monitor=monitor_metric,
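Aside (not part of the diff): the monitored-metric key under the new early_stop_monitor_task option, replayed standalone for a hypothetical two-target model with metrics=["auc"].

metrics = ["auc"]
target_columns = ["click", "purchase"]
early_stop_monitor_task = "purchase"        # None would fall back to "click"

monitor_task = early_stop_monitor_task or target_columns[0]
if monitor_task not in target_columns:
    raise ValueError(f"early_stop_monitor_task '{monitor_task}' not found in {target_columns}")
monitor_metric = f"val_{metrics[0]}_{monitor_task}"
print(monitor_metric)  # val_auc_purchase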
@@ -703,7 +777,9 @@ class BaseModel(FeatureSet, nn.Module):
                 )
             )
 
-        if self.is_main_process and not
+        if self.is_main_process and not any(
+            isinstance(cb, CheckpointSaver) for cb in existing_callbacks
+        ):
             self.callbacks.append(
                 CheckpointSaver(
                     best_path=self.best_path,
@@ -715,7 +791,9 @@ class BaseModel(FeatureSet, nn.Module):
                 )
             )
 
-        if self.scheduler_fn is not None and not
+        if self.scheduler_fn is not None and not any(
+            isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
+        ):
             self.callbacks.append(
                 LearningRateScheduler(
                     scheduler=self.scheduler_fn,
@@ -738,16 +816,16 @@ class BaseModel(FeatureSet, nn.Module):
         self.stop_training = False
         self.best_checkpoint_path = self.best_path
         use_ddp_sampler = (
-
+            auto_ddp_sampler
             and self.distributed
             and dist.is_available()
             and dist.is_initialized()
         )
 
-        if not
+        if not auto_ddp_sampler and self.distributed and self.is_main_process:
             logging.info(
                 colorize(
-                    "[Distributed Info]
+                    "[Distributed Info] auto_ddp_sampler=False; assuming data is already sharded per rank.",
                     color="yellow",
                 )
             )
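Aside (not part of the diff): the new auto_ddp_sampler switch only results in a DistributedSampler when the process group is actually up. The gating expression from the hunk above, replayed standalone:

import torch.distributed as dist

auto_ddp_sampler = False   # new fit() argument; False means each rank already reads its own shard
distributed = True

use_ddp_sampler = (
    auto_ddp_sampler
    and distributed
    and dist.is_available()
    and dist.is_initialized()
)
print(use_ddp_sampler)  # False -> fit() skips sampler attachment and logs the yellow notice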
@@ -826,12 +904,12 @@ class BaseModel(FeatureSet, nn.Module):
         # If split-based loader was built without sampler, attach here when enabled
         if (
             self.distributed
-            and
+            and auto_ddp_sampler
             and isinstance(train_loader, DataLoader)
             and train_sampler is None
         ):
             raise NotImplementedError(
-                "[BaseModel-fit Error]
+                "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
             )
             # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
 
@@ -841,7 +919,7 @@ class BaseModel(FeatureSet, nn.Module):
             needs_user_ids=self.needs_user_ids,
             user_id_column=user_id_column,
             num_workers=num_workers,
-
+            auto_ddp_sampler=auto_ddp_sampler,
         )
         try:
             self.steps_per_epoch = len(train_loader)
@@ -863,7 +941,7 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info("")
             tb_dir = (
                 self.training_logger.tensorboard_logdir
-                if self.training_logger and self.training_logger.
+                if self.training_logger and self.training_logger.use_tensorboard
                 else None
             )
             if tb_dir:
@@ -1055,7 +1133,7 @@ class BaseModel(FeatureSet, nn.Module):
         y_true_list = []
         y_pred_list = []
         collect_metrics = getattr(self, "collect_train_metrics", True)
-        max_samples = getattr(self, "
+        max_samples = getattr(self, "metrics_sample_limit", None)
         collected_samples = 0
         metrics_capped = False
 
@@ -1184,14 +1262,14 @@ class BaseModel(FeatureSet, nn.Module):
         needs_user_ids: bool,
         user_id_column: str | None = "user_id",
         num_workers: int = 0,
-
+        auto_ddp_sampler: bool = True,
     ) -> tuple[DataLoader | None, np.ndarray | None]:
         if valid_data is None:
             return None, None
         if isinstance(valid_data, DataLoader):
-            if
+            if auto_ddp_sampler and self.distributed:
                 raise NotImplementedError(
-                    "[BaseModel-prepare_validation_data Error]
+                    "[BaseModel-prepare_validation_data Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
                 )
             # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
         else:
@@ -1200,7 +1278,7 @@ class BaseModel(FeatureSet, nn.Module):
             valid_sampler = None
             valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True)  # type: ignore
             if (
-
+                auto_ddp_sampler
                 and self.distributed
                 and valid_dataset is not None
                 and dist.is_available()
@@ -1373,11 +1451,11 @@ class BaseModel(FeatureSet, nn.Module):
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int = 32,
         save_path: str | os.PathLike | None = None,
-        save_format:
+        save_format: str = "csv",
         include_ids: bool | None = None,
         id_columns: str | list[str] | None = None,
         return_dataframe: bool = True,
-
+        stream_chunk_size: int = 10000,
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
@@ -1392,7 +1470,7 @@ class BaseModel(FeatureSet, nn.Module):
            include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
            return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
-
+           stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
            num_workers: DataLoader worker count.
         """
         self.eval()
@@ -1413,7 +1491,7 @@ class BaseModel(FeatureSet, nn.Module):
                 save_path=save_path,
                 save_format=save_format,
                 include_ids=include_ids,
-
+                stream_chunk_size=stream_chunk_size,
                 return_dataframe=return_dataframe,
                 id_columns=predict_id_columns,
             )
@@ -1439,7 +1517,7 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_size=batch_size,
                 shuffle=False,
                 streaming=True,
-                chunk_size=
+                chunk_size=stream_chunk_size,
             )
         else:
             data_loader = self.prepare_data_loader(
@@ -1517,11 +1595,18 @@ class BaseModel(FeatureSet, nn.Module):
             else y_pred_all
         )
         if save_path is not None:
-
-
-
+            # Check streaming write support
+            if not check_streaming_support(save_format):
+                logging.warning(
+                    f"[BaseModel-predict Warning] Format '{save_format}' does not support streaming writes. "
+                    "The entire result will be saved at once. Use csv or parquet for large datasets."
                 )
-
+
+            # Get file extension from format
+            from nextrec.utils.data import FILE_FORMAT_CONFIG
+
+            suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+
             target_path = resolve_save_path(
                 path=save_path,
                 default_dir=self.session.predictions_dir,
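Aside (not part of the diff): check_streaming_support and FILE_FORMAT_CONFIG are imported from nextrec.utils.data, whose +154 lines are not shown on this page. A hypothetical sketch of how such helpers could look, inferred only from the way they are used above (the exact keys and format names are assumptions):

FILE_FORMAT_CONFIG = {
    "csv":     {"extension": [".csv"],     "streaming": True},
    "parquet": {"extension": [".parquet"], "streaming": True},
    "feather": {"extension": [".feather"], "streaming": False},
    "excel":   {"extension": [".xlsx"],    "streaming": False},
    "hdf5":    {"extension": [".h5"],      "streaming": False},
}

def check_streaming_support(save_format: str) -> bool:
    """Return True if the format can be appended to chunk by chunk (hypothetical)."""
    return FILE_FORMAT_CONFIG.get(save_format, {}).get("streaming", False)

print(check_streaming_support("parquet"))  # True
print(check_streaming_support("excel"))    # False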
@@ -1540,10 +1625,21 @@ class BaseModel(FeatureSet, nn.Module):
                        f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
                    )
                df_to_save = pd.concat([id_df, df_to_save], axis=1)
+
+            # Save based on format
             if save_format == "csv":
                 df_to_save.to_csv(target_path, index=False)
-
+            elif save_format == "parquet":
                 df_to_save.to_parquet(target_path, index=False)
+            elif save_format == "feather":
+                df_to_save.to_feather(target_path)
+            elif save_format == "excel":
+                df_to_save.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                df_to_save.to_hdf(target_path, key="predictions", mode="w")
+            else:
+                raise ValueError(f"Unsupported save format: {save_format}")
+
             logging.info(
                 colorize(f"Predictions saved to: {target_path}", color="green")
             )
@@ -1554,9 +1650,9 @@ class BaseModel(FeatureSet, nn.Module):
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int,
         save_path: str | os.PathLike,
-        save_format:
+        save_format: str,
         include_ids: bool,
-
+        stream_chunk_size: int,
         return_dataframe: bool,
         id_columns: list[str] | None = None,
     ) -> pd.DataFrame | Path:
@@ -1573,7 +1669,7 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_size=batch_size,
                 shuffle=False,
                 streaming=True,
-                chunk_size=
+                chunk_size=stream_chunk_size,
             )
         elif not isinstance(data, DataLoader):
             data_loader = self.prepare_data_loader(
@@ -1595,7 +1691,17 @@ class BaseModel(FeatureSet, nn.Module):
                 "When using streaming mode, set num_workers=0 to avoid reading data multiple times."
             )
 
-
+        # Check streaming support and prepare file path
+        if not check_streaming_support(save_format):
+            logging.warning(
+                f"[Predict Streaming Warning] Format '{save_format}' does not support streaming writes. "
+                "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
+            )
+
+        from nextrec.utils.data import FILE_FORMAT_CONFIG
+
+        suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
+
         target_path = resolve_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
@@ -1606,9 +1712,10 @@ class BaseModel(FeatureSet, nn.Module):
         target_path.parent.mkdir(parents=True, exist_ok=True)
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None
-
         pred_columns = None
-        collected_frames =
+        collected_frames = (
+            []
+        )  # used when return_dataframe=True or for non-streaming formats
 
         with torch.no_grad():
            for batch_data in progress(data_loader, description="Predicting"):
@@ -1650,27 +1757,48 @@ class BaseModel(FeatureSet, nn.Module):
                    )
                    df_batch = pd.concat([id_df, df_batch], axis=1)
 
+                # Streaming save based on format
                if save_format == "csv":
                    df_batch.to_csv(
                        target_path, mode="a", header=not header_written, index=False
                    )
                    header_written = True
-
+                elif save_format == "parquet":
                    try:
                        import pyarrow as pa
                        import pyarrow.parquet as pq
                    except ImportError as exc:  # pragma: no cover
                        raise ImportError(
-                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow
+                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow."
                        ) from exc
                    table = pa.Table.from_pandas(df_batch, preserve_index=False)
                    if parquet_writer is None:
                        parquet_writer = pq.ParquetWriter(target_path, table.schema)
                    parquet_writer.write_table(table)
-
+                else:
+                    # Non-streaming formats: collect all data
                    collected_frames.append(df_batch)
+
+                if return_dataframe:
+                    if (
+                        save_format in ["csv", "parquet"]
+                        and df_batch not in collected_frames
+                    ):
+                        collected_frames.append(df_batch)
+
+        # Close writers
         if parquet_writer is not None:
             parquet_writer.close()
+        # For non-streaming formats, save collected data
+        if save_format in ["feather", "excel", "hdf5"] and collected_frames:
+            combined_df = pd.concat(collected_frames, ignore_index=True)
+            if save_format == "feather":
+                combined_df.to_feather(target_path)
+            elif save_format == "excel":
+                combined_df.to_excel(target_path, index=False)
+            elif save_format == "hdf5":
+                combined_df.to_hdf(target_path, key="predictions", mode="w")
+
         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
             return (
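Aside (not part of the diff): a standalone illustration of the chunked Parquet write pattern the streaming branch above relies on. One ParquetWriter is opened with the first batch's schema and every later batch is appended without rereading the file; the file name and batch contents here are placeholders.

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

writer = None
for chunk_start in range(0, 100, 25):                 # pretend batches of predictions
    df_batch = pd.DataFrame({"prediction": range(chunk_start, chunk_start + 25)})
    table = pa.Table.from_pandas(df_batch, preserve_index=False)
    if writer is None:
        writer = pq.ParquetWriter("predictions.parquet", table.schema)
    writer.write_table(table)                         # appended as a new row group
if writer is not None:
    writer.close()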
@@ -1691,7 +1819,7 @@ class BaseModel(FeatureSet, nn.Module):
         target_path = resolve_save_path(
             path=save_path,
             default_dir=self.session_path,
-            default_name=self.model_name,
+            default_name=self.model_name.upper(),
             suffix=".pt",
             add_timestamp=add_timestamp,
         )
@@ -1845,7 +1973,9 @@ class BaseModel(FeatureSet, nn.Module):
         logger.info("")
         logger.info(
             colorize(
-                f"Model Summary: {self.model_name}",
+                f"Model Summary: {self.model_name.upper()}",
+                color="bright_blue",
+                bold=True,
             )
         )
         logger.info("")
@@ -1976,7 +2106,7 @@ class BaseModel(FeatureSet, nn.Module):
         logger.info("Other Settings:")
         logger.info(f" Early Stop Patience: {self.early_stop_patience}")
         logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
-        logger.info(f" Max Metrics Samples: {self.
+        logger.info(f" Max Metrics Samples: {self.metrics_sample_limit}")
         logger.info(f" Session ID: {self.session_id}")
         logger.info(f" Features Config Path: {self.features_config_path}")
         logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
@@ -2146,7 +2276,7 @@ class BaseMatchModel(BaseModel):
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
-                f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )
 
         default_loss_by_mode: dict[str, str] = {
@@ -2251,9 +2381,7 @@ class BaseMatchModel(BaseModel):
         user_emb = self.user_tower(user_input)  # [B, D]
         item_emb = self.item_tower(item_input)  # [B, D]
 
-        return self.head(
-            user_emb, item_emb, similarity_fn=self.compute_similarity
-        )
+        return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)
 
     def compute_loss(self, y_pred, y_true):
         if self.training_mode == "pointwise":
@@ -2309,7 +2437,7 @@ class BaseMatchModel(BaseModel):
         features: list,
         batch_size: int,
         num_workers: int = 0,
-
+        stream_chunk_size: int = 10000,
     ) -> DataLoader:
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
@@ -2330,7 +2458,7 @@ class BaseMatchModel(BaseModel):
             batch_size=batch_size,
             shuffle=False,
             streaming=True,
-            chunk_size=
+            chunk_size=stream_chunk_size,
             num_workers=num_workers,
         )
         tensors = build_tensors_from_data(
@@ -2383,7 +2511,7 @@ class BaseMatchModel(BaseModel):
         ),
         batch_size: int = 512,
         num_workers: int = 0,
-
+        stream_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
         data_loader = self.prepare_feature_data(
@@ -2391,7 +2519,7 @@ class BaseMatchModel(BaseModel):
             self.user_features_all,
             batch_size,
             num_workers=num_workers,
-
+            stream_chunk_size=stream_chunk_size,
         )
 
         embeddings_list = []
@@ -2417,7 +2545,7 @@ class BaseMatchModel(BaseModel):
         ),
         batch_size: int = 512,
         num_workers: int = 0,
-
+        stream_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
         data_loader = self.prepare_feature_data(
@@ -2425,7 +2553,7 @@ class BaseMatchModel(BaseModel):
             self.item_features_all,
             batch_size,
             num_workers=num_workers,
-
+            stream_chunk_size=stream_chunk_size,
         )
 
         embeddings_list = []