nextrec 0.4.24__py3-none-any.whl → 0.4.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/asserts.py +72 -0
- nextrec/basic/loggers.py +18 -1
- nextrec/basic/model.py +191 -71
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/models/multi_task/[pre]aitm.py +173 -0
- nextrec/models/multi_task/[pre]snr_trans.py +232 -0
- nextrec/models/multi_task/[pre]star.py +192 -0
- nextrec/models/multi_task/apg.py +330 -0
- nextrec/models/multi_task/cross_stitch.py +229 -0
- nextrec/models/multi_task/escm.py +290 -0
- nextrec/models/multi_task/esmm.py +8 -21
- nextrec/models/multi_task/hmoe.py +203 -0
- nextrec/models/multi_task/mmoe.py +20 -28
- nextrec/models/multi_task/pepnet.py +68 -66
- nextrec/models/multi_task/ple.py +30 -44
- nextrec/models/multi_task/poso.py +13 -22
- nextrec/models/multi_task/share_bottom.py +14 -25
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -4
- nextrec/models/ranking/dcn.py +2 -3
- nextrec/models/ranking/dcn_v2.py +2 -3
- nextrec/models/ranking/deepfm.py +2 -3
- nextrec/models/ranking/dien.py +7 -9
- nextrec/models/ranking/din.py +8 -10
- nextrec/models/ranking/eulernet.py +1 -2
- nextrec/models/ranking/ffm.py +1 -2
- nextrec/models/ranking/fibinet.py +2 -3
- nextrec/models/ranking/fm.py +1 -1
- nextrec/models/ranking/lr.py +1 -1
- nextrec/models/ranking/masknet.py +1 -2
- nextrec/models/ranking/pnn.py +1 -2
- nextrec/models/ranking/widedeep.py +2 -3
- nextrec/models/ranking/xdeepfm.py +2 -4
- nextrec/models/representation/rqvae.py +4 -4
- nextrec/models/retrieval/dssm.py +18 -26
- nextrec/models/retrieval/dssm_v2.py +15 -22
- nextrec/models/retrieval/mind.py +9 -15
- nextrec/models/retrieval/sdm.py +36 -33
- nextrec/models/retrieval/youtube_dnn.py +16 -24
- nextrec/models/sequential/hstu.py +2 -2
- nextrec/utils/__init__.py +5 -1
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +16 -77
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
- nextrec-0.4.27.dist-info/RECORD +90 -0
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec-0.4.24.dist-info/RECORD +0 -86
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 01/01/2026
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -36,6 +36,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from nextrec import __version__
+from nextrec.basic.asserts import assert_task
 from nextrec.basic.callback import (
     CallbackList,
     CheckpointSaver,
@@ -88,9 +89,8 @@ from nextrec.utils.config import safe_value
 from nextrec.utils.model import (
     compute_ranking_loss,
     get_loss_list,
-    resolve_loss_weights,
-    get_training_modes,
 )
+
 from nextrec.utils.types import (
     LossName,
     OptimizerName,
@@ -100,6 +100,8 @@ from nextrec.utils.types import (
     MetricsName,
 )

+from nextrec.utils.data import FILE_FORMAT_CONFIG
+

 class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
@@ -118,7 +120,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: TaskTypeName | list[TaskTypeName] | None = None,
-        training_mode: TrainingModeName | list[TrainingModeName] =
+        training_mode: TrainingModeName | list[TrainingModeName] | None = None,
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
@@ -138,10 +140,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             dense_features: DenseFeature definitions.
             sparse_features: SparseFeature definitions.
             sequence_features: SequenceFeature definitions.
-            target: Target column name. e.g., '
+            target: Target column name. e.g., 'label_ctr' or ['label_ctr', 'label_cvr'].
             id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
             task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-            training_mode: Training mode for
+            training_mode: Training mode for different tasks. e.g., 'pointwise', ['pointwise', 'pairwise'].

             embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
             dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
@@ -193,10 +195,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         self.task = task or self.default_task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-
-
-
-
+
+        training_mode = training_mode or "pointwise"
+        if isinstance(training_mode, list):
+            self.training_modes = list(training_mode)
+        else:
+            self.training_modes = [training_mode] * self.nums_task

         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
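The hunk above makes `training_mode` optional and normalizes it into `self.training_modes`, one entry per task. A minimal usage sketch, assuming an MMOE-style constructor and hypothetical feature/column names (the import path is inferred from the file list above, not shown in this diff):

```python
# Sketch only: arguments besides training_mode/task/target are hypothetical.
from nextrec.models.multi_task.mmoe import MMOE  # path inferred from the file list

model = MMOE(
    sparse_features=sparse_features,           # defined elsewhere
    target=["label_ctr", "label_cvr"],
    task=["binary", "binary"],
    training_mode=["pointwise", "pointwise"],  # one mode per task
)
# A single string broadcasts: training_mode="pointwise" -> ["pointwise", "pointwise"].
# Omitting it entirely falls back to "pointwise" for every task.
```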
@@ -215,6 +219,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         self.train_data_summary = None
         self.valid_data_summary = None
+        self.note = None

     def register_regularization_weights(
         self,
@@ -222,6 +227,15 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         exclude_modules: list[str] | None = None,
         include_modules: list[str] | None = None,
     ):
+        """
+        Register parameters for regularization.
+        By default, all nn.Linear weights (excluding those in BatchNorm/Dropout layers) and embedding weights under `embedding_attr` are registered.
+
+        Args:
+            embedding_attr: Attribute name of the embedding layer/module.
+            exclude_modules: List of module name substrings to exclude from regularization.
+            include_modules: List of module name substrings to include for regularization. If provided, only modules containing these substrings are included.
+        """
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
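Given the new docstring, a call could look like this sketch; the attribute and module names are hypothetical, only the keyword names come from the signature above:

```python
# Sketch: regularize tower weights plus the embedding table, skip bias networks.
model.register_regularization_weights(
    embedding_attr="embedding",    # attribute that holds the embedding module
    include_modules=["tower"],     # substring match: only modules named *tower*
    exclude_modules=["bias_net"],  # hypothetical module name, never regularized
)
reg_loss = model.add_reg_loss()    # weighted L1/L2 penalty over registered params
```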
@@ -268,6 +282,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             existing_reg_ids.add(id(module.weight))

     def add_reg_loss(self) -> torch.Tensor:
+        """
+        Compute the regularization loss based on registered parameters and their respective regularization strengths.
+        """
         reg_loss = torch.tensor(0.0, device=self.device)

         if self.embedding_l1_reg > 0:
@@ -289,9 +306,25 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         )
         return reg_loss

+    # todo: support build pairwise/listwise label in input
     def get_input(self, input_data: dict, require_labels: bool = True):
+        """
+        Prepare unified input features and labels from the given input data.
+
+
+        Args:
+            input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
+            require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
+
+        Note:
+            target tensor shape will always be (batch_size, num_targets)
+        """
         feature_source = input_data.get("features", {})
+        # todo: pairwise/listwise label support
+        # "labels": {...} should contain pointwise/pair index/list index/ relevance scores
+        # now only have pointwise label support
         label_source = input_data.get("labels")
+
         X_input = {}
         for feature in self.all_features:
             if feature.name not in feature_source:
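The `input_data` contract from that docstring, spelled out as a sketch (feature and label names are hypothetical):

```python
import torch

batch = {
    "features": {
        "user_id": torch.tensor([1, 2, 3]),
        "item_id": torch.tensor([10, 20, 30]),
    },
    "labels": {"label_ctr": torch.tensor([1.0, 0.0, 1.0])},
}
X_input, y = model.get_input(batch, require_labels=True)
# Per the Note above, y is stacked to (batch_size, num_targets), here (3, 1).
```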
@@ -307,13 +340,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 device=self.device,
             )
         y = None
+        # if need labels: training or eval with labels
         if len(self.target_columns) > 0 and (
             require_labels
             or (
                 label_source
                 and any(name in label_source for name in self.target_columns)
             )
-        ):
+        ):
             target_tensors = []
             for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
@@ -358,6 +392,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         This function will split training data into training and validation sets when:
         1. valid_data is None;
         2. valid_split is provided.
+
+        Returns:
+            train_loader: DataLoader for training data.
+            valid_split_data: Validation data dict/dataframe split from training data.
         """
         if not (0 < valid_split < 1):
             raise ValueError(
@@ -375,7 +413,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             )
         else:
             raise TypeError(
-                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be
+                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be DataFrame or a dict, now got {type(train_data)}"
             )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
@@ -426,7 +464,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         Args:
             optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
             optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
-            scheduler: Learning rate scheduler name or instance. e.g., '
+            scheduler: Learning rate scheduler name or instance. e.g., 'step', 'cosine', or torch.optim.lr_scheduler.StepLR().
             scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
@@ -435,36 +473,31 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
         """
         self.ignore_label = ignore_label
-
-
-
-
-        }
-
-
-
-        self.loss_params = loss_params or {}
-        optimizer_params = optimizer_params or {}
+
+        # get loss list
+        loss_list = get_loss_list(loss, self.training_modes, self.nums_task)
+
+        self.loss_params = {} if loss_params is None else loss_params
+        self.optimizer_params = optimizer_params or {}
+        self.scheduler_params = scheduler_params or {}
+
         self.optimizer_name = (
             optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         )
-        self.optimizer_params = optimizer_params
         self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
-            **optimizer_params,
+            **self.optimizer_params,
         )
-        scheduler_params = scheduler_params or {}
         if scheduler is None:
             self.scheduler_name = None
         elif isinstance(scheduler, str):
             self.scheduler_name = scheduler
         else:
             self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
-        self.scheduler_params = scheduler_params
         self.scheduler_fn = (
-            get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
+            get_scheduler(scheduler, self.optimizer_fn, **self.scheduler_params)
             if scheduler
             else None
         )
@@ -482,35 +515,56 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             for i in range(self.nums_task)
         ]

+        # loss weighting (grad norm or fixed weights)
         self.grad_norm = None
         self.grad_norm_shared_params = None
-
+        is_grad_norm = (
+            loss_weights == "grad_norm"
+            or isinstance(loss_weights, dict)
+            and loss_weights.get("method") == "grad_norm"
+        )
+        if is_grad_norm:
             if self.nums_task == 1:
                 raise ValueError(
                     "[BaseModel-compile Error] GradNorm requires multi-task setup."
                 )
-
-
+            grad_norm_params = (
+                dict(loss_weights) if isinstance(loss_weights, dict) else {}
             )
-            self.loss_weights = None
-        elif (
-            isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
-        ):
-            if self.nums_task == 1:
-                raise ValueError(
-                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
-                )
-            grad_norm_params = dict(loss_weights)
             grad_norm_params.pop("method", None)
             self.grad_norm = GradNormLossWeighting(
                 nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
             self.loss_weights = None
+        elif loss_weights is None:
+            self.loss_weights = None
+        elif self.nums_task == 1:
+            if isinstance(loss_weights, (list, tuple)):
+                if len(loss_weights) != 1:
+                    raise ValueError(
+                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
+                    )
+                loss_weights = loss_weights[0]
+            self.loss_weights = [float(loss_weights)]
+        elif isinstance(loss_weights, (int, float)):
+            self.loss_weights = [float(loss_weights)] * self.nums_task
+        elif isinstance(loss_weights, (list, tuple)):
+            weights = [float(w) for w in loss_weights]
+            if len(weights) != self.nums_task:
+                raise ValueError(
+                    f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
+                )
+            self.loss_weights = weights
         else:
-
+            raise TypeError(
+                f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
+            )
         self.compiled = True

     def compute_loss(self, y_pred, y_true):
+        """
+        Compute the loss between predictions and ground truth labels, with loss weighting and ignore_label handling
+        """
         if y_true is None:
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
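Taken together, the new branches accept a fixed scalar, a per-task list, or a GradNorm spec for `loss_weights`. A hedged sketch of the accepted forms; the `alpha` key is a guess at a GradNormLossWeighting hyperparameter, not confirmed by this diff:

```python
model.compile(optimizer="adam", loss="bce", loss_weights=0.5)                  # scalar, broadcast
model.compile(optimizer="adam", loss=["bce", "mse"], loss_weights=[0.7, 0.3])  # one weight per task
model.compile(optimizer="adam", loss=["bce", "mse"], loss_weights="grad_norm") # dynamic weighting
model.compile(
    optimizer="adam",
    loss=["bce", "mse"],
    loss_weights={"method": "grad_norm", "alpha": 1.5},  # extra keys forwarded; "alpha" hypothetical
)
```

Worth noting: the `is_grad_norm` expression leans on Python parsing `a or b and c` as `a or (b and c)`, which is what lets it catch both the string and the dict spellings.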
@@ -522,13 +576,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 y_pred = y_pred.view(-1, 1)
             if y_true.dim() == 1:
                 y_true = y_true.view(-1, 1)
-            if y_pred.shape != y_true.shape:
-                raise ValueError(
-                    f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
-                )

             loss_fn = self.loss_fn[0]

+            # mask ignored labels
+            # we don't suggest using ignore_label for single task training
             if self.ignore_label is not None:
                 valid_mask = y_true != self.ignore_label
                 if valid_mask.dim() > 1:
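For context, a sketch of the `ignore_label` mechanism those comments describe; the sentinel value and label tensor are illustrative:

```python
import torch

# Hypothetical setup: -1 marks rows whose label is unknown.
model.compile(optimizer="adam", loss="bce", ignore_label=-1)
y_true = torch.tensor([[1.0], [-1.0], [0.0]])  # the middle row is masked out of the loss
```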
@@ -559,9 +611,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 loss *= self.loss_weights[0]
             return loss

-        # multi-task
-        if y_pred.shape != y_true.shape:
-            raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+        # multi-task: slice predictions and labels per task
         slices = (
             self.prediction_layer.task_slices  # type: ignore
             if hasattr(self, "prediction_layer")
@@ -593,9 +643,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 )
             else:
                 task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-
-
-
+            # task_loss = normalize_task_loss(
+            #     task_loss, valid_count, total_count
+            # )  # normalize by valid samples to avoid loss scale issues
             task_losses.append(task_loss)

         if self.grad_norm is not None:
@@ -619,11 +669,23 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         batch_size: int = 32,
         shuffle: bool = True,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
         return_dataset: bool = False,
     ):
         """
         Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
+
+        Args:
+            data: Input data (dict/df/DataLoader).
+            batch_size: Batch size.
+            shuffle: Whether to shuffle the data (ignored when a sampler is provided).
+            num_workers: Number of DataLoader workers.
+            prefetch_factor: Number of batches loaded in advance by each worker.
+            sampler: Optional sampler for DataLoader.
+            return_dataset: Whether to return the tensor dataset along with the DataLoader, used for valid data
+        Returns:
+            DataLoader (and tensor dataset if return_dataset is True).
         """
         if isinstance(data, DataLoader):
             return (data, None) if return_dataset else data
@@ -639,6 +701,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
             )
         dataset = TensorDictDataset(tensors)
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
         loader = DataLoader(
             dataset,
             batch_size=batch_size,
@@ -648,6 +713,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             num_workers=num_workers,
             pin_memory=self.device.type == "cuda",
             persistent_workers=num_workers > 0,
+            **loader_kwargs,
         )
         return (loader, dataset) if return_dataset else loader
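The conditional `loader_kwargs` exists because PyTorch's `DataLoader` accepts `prefetch_factor` only when worker processes are used; passing it with `num_workers=0` raises a `ValueError`. A usage sketch with a hypothetical DataFrame:

```python
loader = model.prepare_data_loader(
    train_df,           # hypothetical features-plus-labels DataFrame
    batch_size=512,
    shuffle=True,
    num_workers=4,
    prefetch_factor=4,  # each worker keeps 4 batches in flight; dropped if num_workers == 0
)
```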
@@ -676,6 +742,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         swanlab_kwargs: dict | None = None,
         auto_ddp_sampler: bool = True,
         log_interval: int = 1,
+        note: str | None = None,
         summary_sections: (
             list[Literal["feature", "model", "train", "data"]] | None
         ) = None,
@@ -707,6 +774,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             swanlab_kwargs: Optional kwargs for swanlab.init(...).
             auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
             log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+            note: Optional note for the training run.
             summary_sections: Optional summary sections to print. Choose from
                 ["feature", "model", "train", "data"]. Defaults to all.

@@ -720,6 +788,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         )
         self.to(self.device)

+        assert_task(self.task, len(self.target_columns), model_name=self.model_name)
+
         if not self.compiled:
             self.compile(
                 optimizer="adam",
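`assert_task` comes from the new `nextrec/basic/asserts.py` (+72 lines in this release). The call shape below mirrors what `fit()` now does; the assumption that it raises on a task/target count mismatch is inferred from its placement, not shown in this diff:

```python
from nextrec.basic.asserts import assert_task

# Presumed failure case: two tasks declared against a single target column.
assert_task(["binary", "regression"], 1, model_name="mmoe")
```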
@@ -770,11 +840,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         self.metrics_sample_limit = (
             None if metrics_sample_limit is None else int(metrics_sample_limit)
         )
+        self.note = note

         training_config = {}
         if self.is_main_process:
             training_config = {
                 "model_name": getattr(self, "model_name", self.__class__.__name__),
+                "note": self.note,
                 "task": self.task,
                 "target_columns": self.target_columns,
                 "batch_size": batch_size,
@@ -822,6 +894,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         else:
             swanlab.login(api_key=swanlab_api)

+        if use_wandb and self.note:
+            wandb_kwargs = dict(wandb_kwargs or {})
+            wandb_kwargs.setdefault("notes", self.note)
+
+        if use_swanlab and self.note:
+            swanlab_kwargs = dict(swanlab_kwargs or {})
+            swanlab_kwargs.setdefault("description", self.note)
+
         self.training_logger = (
             TrainingLogger(
                 session=self.session,
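So a free-form `note` is stored on the model, recorded in `training_config`, and threaded into both trackers; `setdefault` keeps any `notes`/`description` the caller already set. A sketch; `use_wandb` is assumed from the body above and the data variables are hypothetical:

```python
model.fit(
    train_data=train_df,
    valid_data=valid_df,
    note="baseline + prefetch_factor=4",  # surfaces as wandb notes / swanlab description
    use_wandb=True,
)
```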
@@ -1253,7 +1333,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         for batch_index, batch_data in batch_iter:
             batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
-            # call via __call__ so DDP hooks run
+            # call via __call__ so DDP hooks run
             y_pred = model(X_input)  # type: ignore

             loss = self.compute_loss(y_pred, y_true)
@@ -1556,7 +1636,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
-        Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.

         Args:
@@ -1569,6 +1648,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
             stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
             num_workers: DataLoader worker count.
+
+        Note:
+            predict does not support distributed mode currently, consider it as a single-process operation.
         """
         self.eval()
         # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
@@ -1753,6 +1835,21 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         return_dataframe: bool,
         id_columns: list[str] | None = None,
     ):
+        """
+        Make predictions on the given data using streaming mode for large datasets.
+
+        Args:
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction.
+            save_path: Path to save predictions.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output.
+            stream_chunk_size: Number of rows per chunk when using streaming mode.
+            return_dataframe: Whether to return predictions as a pandas DataFrame.
+            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+        Note:
+            This method uses streaming writes to handle large datasets without loading all data into memory.
+        """
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
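Reading these parameters together with `predict()` above, a large-file inference call could look like the sketch below; the input path and the `data=` keyword name are assumptions, while `batch_size`, `stream_chunk_size`, and `return_dataframe` are confirmed `predict()` keywords in this diff:

```python
result = model.predict(
    data="data/test_samples.parquet",  # hypothetical file input
    batch_size=2048,
    stream_chunk_size=100_000,         # rows per streamed chunk
    return_dataframe=False,            # keeps memory flat; a Path or array comes back
)
```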
@@ -1795,8 +1892,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
         )

-        from nextrec.utils.data import FILE_FORMAT_CONFIG
-
         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

         target_path = get_save_path(
@@ -1908,6 +2003,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         add_timestamp: bool | None = None,
         verbose: bool = True,
     ):
+        """
+        Save the model state and features configuration to disk.
+
+        Args:
+            save_path: Path to save the model; if None, saves to the session's model directory.
+            add_timestamp: Whether to add a timestamp to the filename; if None, defaults to True.
+            verbose: Whether to log the save location.
+        """
         add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = get_save_path(
             path=save_path,
@@ -1950,6 +2053,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         map_location: str | torch.device | None = "cpu",
         verbose: bool = True,
     ):
+        """
+        Load the model state and features configuration from disk.
+
+        Args:
+            save_path: Path to load the model from; can be a directory or a specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            verbose: Whether to log the load location.
+        """
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
@@ -2016,6 +2127,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         """
         Load a model from a checkpoint path. The checkpoint path should contain:
         a .pt file and a features_config.pkl file.
+
+        Args:
+            checkpoint_path: Path to the checkpoint directory or specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            device: Device to place the model on after loading.
+            session_id: Optional session ID for the model.
+            **kwargs: Additional keyword arguments to pass to the model constructor.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
@@ -2135,6 +2253,7 @@ class BaseMatchModel(BaseModel):
             target=target,
             id_columns=id_columns,
             task=task,
+            training_mode=training_mode,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
@@ -2157,10 +2276,14 @@ class BaseMatchModel(BaseModel):
         self.item_sparse_features = item_sparse_features
         self.item_sequence_features = item_sequence_features

-        self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+        if primary_mode not in self.support_training_modes:
+            raise ValueError(
+                f"{self.model_name.upper()} does not support training_mode='{primary_mode}'. Supported modes: {self.support_training_modes}"
+            )
         self.user_features_all = (
             self.user_dense_features
             + self.user_sparse_features
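Two behavior changes land in this hunk: `BaseMatchModel` drops its own `self.training_mode` in favor of the normalized `self.training_modes` from `BaseModel`, and the supported-mode check moves from `compile()` into `__init__`, so an unsupported mode now fails at construction time. A sketch of the failure path; the model class, its constructor arguments, and its supported modes are all hypothetical:

```python
try:
    model = DSSM(user_features=user_feats, item_features=item_feats,  # placeholder args
                 training_mode="listwise")
except ValueError as err:
    # e.g. "DSSM does not support training_mode='listwise'. Supported modes: ..."
    print(err)
```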
@@ -2176,7 +2299,7 @@ class BaseMatchModel(BaseModel):
         self.head = RetrievalHead(
             similarity_metric=self.similarity_metric,
             temperature=self.temperature,
-            training_mode=
+            training_mode=primary_mode,
             apply_sigmoid=True,
         )
@@ -2209,11 +2332,6 @@ class BaseMatchModel(BaseModel):
             loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
             loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
         """
-        if self.training_mode not in self.support_training_modes:
-            raise ValueError(
-                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
-            )
-
         default_loss_by_mode = {
             "pointwise": "bce",
             "pairwise": "bpr",
@@ -2221,26 +2339,27 @@ class BaseMatchModel(BaseModel):
         }

         effective_loss = loss
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
         if effective_loss is None:
-            effective_loss = default_loss_by_mode[
+            effective_loss = default_loss_by_mode[primary_mode]
         elif isinstance(effective_loss, str):
-            if
+            if primary_mode in {"pairwise", "listwise"} and effective_loss in {
                 "bce",
                 "binary_crossentropy",
             }:
-                effective_loss = default_loss_by_mode[
+                effective_loss = default_loss_by_mode[primary_mode]
         elif isinstance(effective_loss, list):
             if not effective_loss:
-                effective_loss = [default_loss_by_mode[
+                effective_loss = [default_loss_by_mode[primary_mode]]
             else:
                 first = effective_loss[0]
                 if (
-
+                    primary_mode in {"pairwise", "listwise"}
                     and isinstance(first, str)
                     and first in {"bce", "binary_crossentropy"}
                 ):
                     effective_loss = [
-                        default_loss_by_mode[
+                        default_loss_by_mode[primary_mode],
                         *effective_loss[1:],
                     ]
         return super().compile(
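The intent of `default_loss_by_mode`: with no explicit loss, each mode gets its natural default ('bce' for pointwise, 'bpr' for pairwise; the listwise entry sits on a line this diff does not show), and a 'bce'/'binary_crossentropy' requested under a pairwise or listwise mode is coerced to the mode default. A sketch, assuming a model constructed in pairwise mode:

```python
model.compile(optimizer="adam")              # loss=None  -> "bpr" under pairwise mode
model.compile(optimizer="adam", loss="bce")  # coerced to "bpr" as well
model.compile(optimizer="adam", loss="bpr")  # passed through unchanged
```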
@@ -2318,11 +2437,12 @@ class BaseMatchModel(BaseModel):
         return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)

     def compute_loss(self, y_pred, y_true):
-        if self.
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+        if primary_mode == "pointwise":
             return super().compute_loss(y_pred, y_true)

         # pairwise / listwise using inbatch neg
-        elif
+        elif primary_mode in ["pairwise", "listwise"]:
             if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
                 raise ValueError(
                     "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
@@ -2365,7 +2485,7 @@ class BaseMatchModel(BaseModel):
             loss *= float(self.loss_weights[0])
             return loss
         else:
-            raise ValueError(f"Unknown training mode: {
+            raise ValueError(f"Unknown training mode: {primary_mode}")

     def prepare_feature_data(
         self,