nextrec 0.4.23__py3-none-any.whl → 0.4.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +96 -46
- nextrec/basic/metrics.py +128 -113
- nextrec/basic/model.py +201 -76
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +27 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/models/multi_task/aitm.py +0 -0
- nextrec/models/multi_task/apg.py +0 -0
- nextrec/models/multi_task/cross_stitch.py +0 -0
- nextrec/models/multi_task/esmm.py +2 -2
- nextrec/models/multi_task/mmoe.py +4 -4
- nextrec/models/multi_task/pepnet.py +335 -0
- nextrec/models/multi_task/ple.py +8 -5
- nextrec/models/multi_task/poso.py +13 -11
- nextrec/models/multi_task/share_bottom.py +4 -4
- nextrec/models/multi_task/snr_trans.py +0 -0
- nextrec/models/ranking/dcn_v2.py +1 -1
- nextrec/models/retrieval/dssm.py +4 -4
- nextrec/models/retrieval/dssm_v2.py +4 -4
- nextrec/models/retrieval/mind.py +2 -2
- nextrec/models/retrieval/sdm.py +4 -4
- nextrec/models/retrieval/youtube_dnn.py +4 -4
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +17 -64
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.23.dist-info → nextrec-0.4.25.dist-info}/METADATA +5 -5
- {nextrec-0.4.23.dist-info → nextrec-0.4.25.dist-info}/RECORD +33 -28
- {nextrec-0.4.23.dist-info → nextrec-0.4.25.dist-info}/WHEEL +0 -0
- {nextrec-0.4.23.dist-info → nextrec-0.4.25.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.23.dist-info → nextrec-0.4.25.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 31/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -88,9 +88,8 @@ from nextrec.utils.config import safe_value
 from nextrec.utils.model import (
     compute_ranking_loss,
     get_loss_list,
-    resolve_loss_weights,
-    get_training_modes,
 )
+
 from nextrec.utils.types import (
     LossName,
     OptimizerName,
@@ -100,6 +99,7 @@ from nextrec.utils.types import (
     MetricsName,
 )

+from nextrec.utils.data import FILE_FORMAT_CONFIG

 class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
@@ -110,6 +110,30 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     def default_task(self) -> TaskTypeName | list[TaskTypeName]:
         raise NotImplementedError

+    @property
+    def training_mode(self) -> TrainingModeName | list[TrainingModeName]:
+        if self.nums_task > 1:
+            return self.training_modes
+        return self.training_modes[0] if self.training_modes else "pointwise"
+
+
+    @training_mode.setter
+    def training_mode(self, training_mode: TrainingModeName | list[TrainingModeName]):
+        valid_modes = {"pointwise", "pairwise", "listwise"}
+        if isinstance(training_mode, list):
+            training_modes = list(training_mode)
+            if len(training_modes) != self.nums_task:
+                raise ValueError(
+                    "[BaseModel-init Error] training_mode list length must match number of tasks."
+                )
+        else:
+            training_modes = [training_mode] * self.nums_task
+        if any(mode not in valid_modes for mode in training_modes):
+            raise ValueError(
+                "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+            )
+        self.training_modes = list(training_modes)
+
     def __init__(
         self,
         dense_features: list[DenseFeature] | None = None,
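
The new training_mode property and setter normalize a single mode string to one mode per task and validate against {'pointwise', 'pairwise', 'listwise'}. A standalone sketch of that normalization logic, with nums_task fixed to 2 for illustration (function and variable names here are illustrative, not the library's internals):

valid_modes = {"pointwise", "pairwise", "listwise"}
nums_task = 2

def normalize_training_mode(training_mode):
    # Broadcast a single mode to every task; a list must match nums_task.
    if isinstance(training_mode, list):
        modes = list(training_mode)
        if len(modes) != nums_task:
            raise ValueError("training_mode list length must match number of tasks.")
    else:
        modes = [training_mode] * nums_task
    if any(m not in valid_modes for m in modes):
        raise ValueError("training_mode must be one of {'pointwise', 'pairwise', 'listwise'}.")
    return modes

print(normalize_training_mode("pairwise"))                 # ['pairwise', 'pairwise']
print(normalize_training_mode(["pointwise", "listwise"]))  # ['pointwise', 'listwise']
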
@@ -193,10 +217,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         self.task = task or self.default_task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-
-        self.training_mode = (
-            self.training_modes if self.nums_task > 1 else self.training_modes[0]
-        )
+
+        self.training_mode = training_mode

         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
@@ -215,6 +237,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         self.train_data_summary = None
         self.valid_data_summary = None
+        self.note = None

     def register_regularization_weights(
         self,
@@ -222,6 +245,15 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         exclude_modules: list[str] | None = None,
         include_modules: list[str] | None = None,
     ):
+        """
+        Register parameters for regularization.
+        By default, all nn.Linear weights (excluding those in BatchNorm/Dropout layers) and embedding weights under `embedding_attr` are registered.
+
+        Args:
+            embedding_attr: Attribute name of the embedding layer/module.
+            exclude_modules: List of module name substrings to exclude from regularization.
+            include_modules: List of module name substrings to include for regularization. If provided, only modules containing these substrings are included.
+        """
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
@@ -268,6 +300,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             existing_reg_ids.add(id(module.weight))

     def add_reg_loss(self) -> torch.Tensor:
+        """
+        Compute the regularization loss based on registered parameters and their respective regularization strengths.
+        """
         reg_loss = torch.tensor(0.0, device=self.device)

         if self.embedding_l1_reg > 0:
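
register_regularization_weights collects linear and embedding weights, and add_reg_loss folds them into a single penalty added to the training loss. A minimal standalone sketch of the same kind of L1/L2 accumulation (layers and strength names are illustrative, not the library's internals):

import torch
import torch.nn as nn

linear = nn.Linear(8, 4)
embedding = nn.Embedding(100, 8)
dense_l2_reg, embedding_l1_reg = 1e-5, 1e-6

# L2 on dense weights, L1 on embedding weights, summed into one scalar
reg_loss = torch.tensor(0.0)
reg_loss = reg_loss + dense_l2_reg * (linear.weight ** 2).sum()
reg_loss = reg_loss + embedding_l1_reg * embedding.weight.abs().sum()
print(reg_loss)  # scalar tensor added to the training loss
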
@@ -289,9 +324,25 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         )
         return reg_loss

+    # todo: support build pairwise/listwise label in input
     def get_input(self, input_data: dict, require_labels: bool = True):
+        """
+        Prepare unified input features and labels from the given input data.
+
+
+        Args:
+            input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
+            require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
+
+        Note:
+            target tensor shape will always be (batch_size, num_targets)
+        """
         feature_source = input_data.get("features", {})
+        # todo: pairwise/listwise label support
+        # "labels": {...} should contain pointwise/pair index/list index/ relevance scores
+        # now only have pointwise label support
         label_source = input_data.get("labels")
+
         X_input = {}
         for feature in self.all_features:
             if feature.name not in feature_source:
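
The new docstring pins down the expected input layout. A sketch of a conforming input_data dict (the feature and label names are made up for illustration):

input_data = {
    "features": {
        "user_id": [1, 2, 3],
        "item_id": [10, 20, 30],
    },
    # pointwise labels only for now; pairwise/listwise are still todo per the comments above
    "labels": {"click": [1, 0, 1]},
}
# X_input, y = model.get_input(input_data)  # y shape: (batch_size, num_targets)
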
@@ -307,13 +358,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 device=self.device,
             )
         y = None
+        # if need labels: training or eval with labels
         if len(self.target_columns) > 0 and (
             require_labels
             or (
                 label_source
                 and any(name in label_source for name in self.target_columns)
             )
-        ):
+        ):
             target_tensors = []
             for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
@@ -358,6 +410,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         This function will split training data into training and validation sets when:
         1. valid_data is None;
         2. valid_split is provided.
+
+        Returns:
+            train_loader: DataLoader for training data.
+            valid_split_data: Validation data dict/dataframe split from training data.
         """
         if not (0 < valid_split < 1):
             raise ValueError(
@@ -375,7 +431,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             )
         else:
             raise TypeError(
-                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be
+                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be DataFrame or a dict, now got {type(train_data)}"
             )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
@@ -426,7 +482,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         Args:
             optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
             optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
-            scheduler: Learning rate scheduler name or instance. e.g., '
+            scheduler: Learning rate scheduler name or instance. e.g., 'step', 'cosine', or torch.optim.lr_scheduler.StepLR().
             scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
@@ -435,36 +491,31 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
         """
         self.ignore_label = ignore_label
-        default_losses = {
-            "pointwise": "bce",
-            "pairwise": "bpr",
-            "listwise": "listnet",
-        }
         loss_list = get_loss_list(
-            loss, self.training_modes, self.nums_task
+            loss, self.training_modes, self.nums_task
         )
-
-
+
+        self.loss_params = {} if loss_params is None else loss_params
+        self.optimizer_params = optimizer_params or {}
+        self.scheduler_params = scheduler_params or {}
+
         self.optimizer_name = (
             optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         )
-        self.optimizer_params = optimizer_params
         self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
-            **optimizer_params,
+            **self.optimizer_params,
         )

-        scheduler_params = scheduler_params or {}
         if scheduler is None:
             self.scheduler_name = None
         elif isinstance(scheduler, str):
             self.scheduler_name = scheduler
         else:
             self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
-        self.scheduler_params = scheduler_params
         self.scheduler_fn = (
-            get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
+            get_scheduler(scheduler, self.optimizer_fn, **self.scheduler_params)
             if scheduler
             else None
         )
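
Parameter defaults are now normalized up front (self.loss_params, self.optimizer_params, self.scheduler_params) before the optimizer and scheduler are built. A usage sketch assembled from the docstring examples above (model construction omitted; a sketch, not a verbatim API transcript):

model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    scheduler="step",
    scheduler_params={"step_size": 10, "gamma": 0.1},
    loss="bce",
)
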
@@ -482,35 +533,54 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             for i in range(self.nums_task)
         ]

+        # loss weighting (grad norm or fixed weights)
         self.grad_norm = None
         self.grad_norm_shared_params = None
-
-
-
-
-
-
-                nums_task=self.nums_task, device=self.device
-            )
-            self.loss_weights = None
-        elif (
-            isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
-        ):
+        is_grad_norm = (
+            loss_weights == "grad_norm"
+            or isinstance(loss_weights, dict)
+            and loss_weights.get("method") == "grad_norm"
+        )
+        if is_grad_norm:
             if self.nums_task == 1:
                 raise ValueError(
                     "[BaseModel-compile Error] GradNorm requires multi-task setup."
                 )
-            grad_norm_params = dict(loss_weights)
+            grad_norm_params = dict(loss_weights) if isinstance(loss_weights, dict) else {}
             grad_norm_params.pop("method", None)
             self.grad_norm = GradNormLossWeighting(
                 nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
             self.loss_weights = None
+        elif loss_weights is None:
+            self.loss_weights = None
+        elif self.nums_task == 1:
+            if isinstance(loss_weights, (list, tuple)):
+                if len(loss_weights) != 1:
+                    raise ValueError(
+                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
+                    )
+                loss_weights = loss_weights[0]
+            self.loss_weights = [float(loss_weights)]
+        elif isinstance(loss_weights, (int, float)):
+            self.loss_weights = [float(loss_weights)] * self.nums_task
+        elif isinstance(loss_weights, (list, tuple)):
+            weights = [float(w) for w in loss_weights]
+            if len(weights) != self.nums_task:
+                raise ValueError(
+                    f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
+                )
+            self.loss_weights = weights
         else:
-
+            raise TypeError(
+                f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
+            )
         self.compiled = True

     def compute_loss(self, y_pred, y_true):
+        """
+        Compute the loss between predictions and ground truth labels, with loss weighting and ignore_label handling
+        """
         if y_true is None:
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
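
The rewritten branching accepts loss_weights as None, a scalar, a list/tuple whose length matches nums_task, the string "grad_norm", or a dict with method='grad_norm' whose remaining keys are forwarded to GradNormLossWeighting. A hedged sketch of accepted values; the "alpha" key in the last line is hypothetical, standing in for whatever kwargs GradNormLossWeighting takes:

# fixed per-task weights
model.compile(loss=["bce", "bce"], loss_weights=[0.7, 0.3])
# GradNorm with defaults
# model.compile(loss=["bce", "bce"], loss_weights="grad_norm")
# GradNorm with forwarded kwargs ("alpha" is hypothetical)
# model.compile(loss=["bce", "bce"], loss_weights={"method": "grad_norm", "alpha": 1.5})
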
@@ -522,13 +592,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             y_pred = y_pred.view(-1, 1)
         if y_true.dim() == 1:
             y_true = y_true.view(-1, 1)
-        if y_pred.shape != y_true.shape:
-            raise ValueError(
-                f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
-            )

         loss_fn = self.loss_fn[0]
-
+
+        # mask ignored labels
+        # we don't suggest using ignore_label for single task training
         if self.ignore_label is not None:
             valid_mask = y_true != self.ignore_label
             if valid_mask.dim() > 1:
@@ -559,9 +627,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 loss *= self.loss_weights[0]
             return loss

-        # multi-task
-        if y_pred.shape != y_true.shape:
-            raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+        # multi-task: slice predictions and labels per task
         slices = (
             self.prediction_layer.task_slices  # type: ignore
             if hasattr(self, "prediction_layer")
@@ -593,9 +659,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 )
             else:
                 task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-
-
-
+            # task_loss = normalize_task_loss(
+            #     task_loss, valid_count, total_count
+            # ) # normalize by valid samples to avoid loss scale issues
             task_losses.append(task_loss)

         if self.grad_norm is not None:
@@ -624,6 +690,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     ):
         """
         Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
+
+        Args:
+            data: Input data (dict/df/DataLoader).
+            batch_size: Batch size.
+            shuffle: Whether to shuffle the data (ignored when a sampler is provided).
+            num_workers: Number of DataLoader workers.
+            sampler: Optional sampler for DataLoader.
+            return_dataset: Whether to return the tensor dataset along with the DataLoader, used for valid data
+        Returns:
+            DataLoader (and tensor dataset if return_dataset is True).
         """
         if isinstance(data, DataLoader):
             return (data, None) if return_dataset else data
@@ -646,6 +722,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             sampler=sampler,
             collate_fn=collate_fn,
             num_workers=num_workers,
+            pin_memory=self.device.type == "cuda",
+            persistent_workers=num_workers > 0,
         )
         return (loader, dataset) if return_dataset else loader
 
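
The two new DataLoader flags are standard PyTorch: pin_memory allocates page-locked host buffers to speed up CPU-to-GPU copies, and persistent_workers keeps worker processes alive across epochs instead of respawning them. A generic, self-contained illustration (dataset contents are arbitrary):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256, 1)))
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,  # only valid when num_workers > 0
)
for X, y in loader:
    pass  # batches arrive from long-lived workers
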
@@ -674,6 +752,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         swanlab_kwargs: dict | None = None,
         auto_ddp_sampler: bool = True,
         log_interval: int = 1,
+        note: str | None = None,
         summary_sections: (
             list[Literal["feature", "model", "train", "data"]] | None
         ) = None,
@@ -705,6 +784,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             swanlab_kwargs: Optional kwargs for swanlab.init(...).
             auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
             log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+            note: Optional note for the training run.
             summary_sections: Optional summary sections to print. Choose from
                 ["feature", "model", "train", "data"]. Defaults to all.
 
@@ -768,11 +848,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         self.metrics_sample_limit = (
             None if metrics_sample_limit is None else int(metrics_sample_limit)
         )
+        self.note = note

         training_config = {}
         if self.is_main_process:
             training_config = {
                 "model_name": getattr(self, "model_name", self.__class__.__name__),
+                "note": self.note,
                 "task": self.task,
                 "target_columns": self.target_columns,
                 "batch_size": batch_size,
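
A sketch of the new fit() arguments (data and model setup omitted). Note that with the gating introduced below, validation is now skipped entirely on epochs where it would not be logged, rather than computed and merely not printed:

model.fit(
    train_data,
    valid_data=valid_data,
    epochs=20,
    log_interval=5,  # evaluate/log validation every 5th epoch and on the last epoch
    note="baseline run with pairwise sampling",  # stored on the model, added to training_config
)
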
@@ -1119,16 +1201,17 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                     train_log_payload, step=epoch + 1, split="train"
                 )
             if valid_loader is not None:
-
-                val_metrics = self.evaluate(
-                    valid_loader,
-                    user_ids=valid_user_ids if self.needs_user_ids else None,
-                    num_workers=num_workers,
-                )
-                should_log_valid = (epoch + 1) % log_interval == 0 or (
+                should_eval_valid = (epoch + 1) % log_interval == 0 or (
                     epoch + 1
                 ) == epochs
-
+                val_metrics = None
+                if should_eval_valid:
+                    self.callbacks.on_validation_begin()
+                    val_metrics = self.evaluate(
+                        valid_loader,
+                        user_ids=valid_user_ids if self.needs_user_ids else None,
+                        num_workers=num_workers,
+                    )
                 display_metrics_table(
                     epoch=epoch + 1,
                     epochs=epochs,
@@ -1142,23 +1225,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                     is_main_process=self.is_main_process,
                     colorize=lambda s: colorize(" " + s, color="cyan"),
                 )
-
-
-
-
-
+                self.callbacks.on_validation_end()
+                if val_metrics and self.training_logger:
+                    self.training_logger.log_metrics(
+                        val_metrics, step=epoch + 1, split="valid"
+                    )
                 if not val_metrics:
-                    if self.is_main_process:
+                    if should_eval_valid and self.is_main_process:
                         logging.info(
                             colorize(
                                 "Warning: No validation metrics computed. Skipping validation for this epoch.",
                                 color="yellow",
                             )
                         )
-
-
-
-
+                    epoch_logs = {**train_log_payload}
+                else:
+                    epoch_logs = {**train_log_payload}
+                    for k, v in val_metrics.items():
+                        epoch_logs[f"val_{k}"] = v
             else:
                 epoch_logs = {**train_log_payload}
             if self.is_main_process:
@@ -1249,7 +1333,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         for batch_index, batch_data in batch_iter:
             batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
-            # call via __call__ so DDP hooks run
+            # call via __call__ so DDP hooks run
             y_pred = model(X_input)  # type: ignore

             loss = self.compute_loss(y_pred, y_true)
@@ -1340,6 +1424,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 target_names=self.target_columns,
                 task_specific_metrics=self.task_specific_metrics,
                 user_ids=combined_user_ids,
+                ignore_label=self.ignore_label,
             )
             return avg_loss, metrics_dict
         return avg_loss
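
Evaluation now forwards ignore_label into metric computation, matching the masking compute_loss already applies. A standalone sketch of that masking idea (values illustrative):

import torch

ignore_label = -1
y_true = torch.tensor([[1.0], [-1.0], [0.0]])  # middle sample has an unknown label
y_pred = torch.tensor([[0.9], [0.4], [0.2]])

valid_mask = (y_true != ignore_label).view(-1)
loss = torch.nn.functional.binary_cross_entropy(y_pred[valid_mask], y_true[valid_mask])
print(loss)  # only the two labeled samples contribute
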
@@ -1387,6 +1472,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             sampler=valid_sampler,
             collate_fn=collate_fn,
             num_workers=num_workers,
+            pin_memory=self.device.type == "cuda",
+            persistent_workers=num_workers > 0,
         )
         valid_user_ids = None
         if needs_user_ids:
@@ -1532,6 +1619,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             target_names=self.target_columns,
             task_specific_metrics=self.task_specific_metrics,
             user_ids=final_user_ids,
+            ignore_label=self.ignore_label,
         )
         return metrics_dict
 
@@ -1548,7 +1636,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
-        Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.

         Args:
@@ -1561,6 +1648,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
             stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
             num_workers: DataLoader worker count.
+
+        Note:
+            predict does not support distributed mode currently, consider it as a single-process operation.
         """
         self.eval()
         # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
@@ -1745,6 +1835,21 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         return_dataframe: bool,
         id_columns: list[str] | None = None,
     ):
+        """
+        Make predictions on the given data using streaming mode for large datasets.
+
+        Args:
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction.
+            save_path: Path to save predictions.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output.
+            stream_chunk_size: Number of rows per chunk when using streaming mode.
+            return_dataframe: Whether to return predictions as a pandas DataFrame.
+            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+        Note:
+            This method uses streaming writes to handle large datasets without loading all data into memory.
+        """
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
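
A sketch of streaming prediction to disk. Paths are illustrative, and the keyword names mirror the internal streaming helper documented above; whether the public predict() accepts all of them is an assumption here:

preds_path = model.predict(
    "data/test.parquet",            # illustrative path
    batch_size=4096,
    save_path="outputs/preds",       # assumed to be forwarded to the streaming helper
    save_format="parquet",           # csv/parquet stream to disk chunk by chunk
    include_ids=True,
    stream_chunk_size=100_000,
    return_dataframe=False,          # avoid materializing everything in memory
)
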
@@ -1787,8 +1892,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
             )

-        from nextrec.utils.data import FILE_FORMAT_CONFIG
-
         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

         target_path = get_save_path(
@@ -1900,6 +2003,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         add_timestamp: bool | None = None,
         verbose: bool = True,
     ):
+        """
+        Save the model state and features configuration to disk.
+
+        Args:
+            save_path: Path to save the model; if None, saves to the session's model directory.
+            add_timestamp: Whether to add a timestamp to the filename; if None, defaults to True.
+            verbose: Whether to log the save location.
+        """
         add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = get_save_path(
             path=save_path,
@@ -1942,6 +2053,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         map_location: str | torch.device | None = "cpu",
         verbose: bool = True,
     ):
+        """
+        Load the model state and features configuration from disk.
+
+        Args:
+            save_path: Path to load the model from; can be a directory or a specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            verbose: Whether to log the load location.
+        """
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
@@ -2008,6 +2127,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         """
         Load a model from a checkpoint path. The checkpoint path should contain:
         a .pt file and a features_config.pkl file.
+
+        Args:
+            checkpoint_path: Path to the checkpoint directory or specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            device: Device to place the model on after loading.
+            session_id: Optional session ID for the model.
+            **kwargs: Additional keyword arguments to pass to the model constructor.
+        """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
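
A sketch of the save/load round trip whose docstrings were added above. The method names save_model and load_from_checkpoint are hypothetical placeholders, since the actual def lines sit outside this diff; paths are illustrative:

model.save_model(save_path="checkpoints/my_model", add_timestamp=False)        # hypothetical name
restored = MyModel.load_from_checkpoint("checkpoints/my_model", map_location="cpu")  # hypothetical name
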
@@ -2127,6 +2253,7 @@ class BaseMatchModel(BaseModel):
             target=target,
             id_columns=id_columns,
             task=task,
+            training_mode=training_mode,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
@@ -2149,10 +2276,13 @@ class BaseMatchModel(BaseModel):
         self.item_sparse_features = item_sparse_features
         self.item_sequence_features = item_sequence_features

-        self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
+        if self.training_mode not in self.support_training_modes:
+            raise ValueError(
+                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+            )
         self.user_features_all = (
             self.user_dense_features
             + self.user_sparse_features
@@ -2201,11 +2331,6 @@ class BaseMatchModel(BaseModel):
             loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
             loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
         """
-        if self.training_mode not in self.support_training_modes:
-            raise ValueError(
-                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
-            )
-
         default_loss_by_mode = {
             "pointwise": "bce",
             "pairwise": "bpr",