nextrec-0.4.24-py3-none-any.whl → nextrec-0.4.25-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +175 -58
- nextrec/basic/summary.py +58 -0
- nextrec/cli.py +13 -0
- nextrec/data/data_processing.py +3 -9
- nextrec/data/dataloader.py +25 -2
- nextrec/data/preprocessor.py +283 -36
- nextrec/utils/config.py +2 -0
- nextrec/utils/model.py +14 -70
- nextrec/utils/torch_utils.py +11 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/METADATA +4 -4
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/RECORD +15 -15
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/WHEEL +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.24.dist-info → nextrec-0.4.25.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.24"
+__version__ = "0.4.25"
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 31/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -88,9 +88,8 @@ from nextrec.utils.config import safe_value
 from nextrec.utils.model import (
     compute_ranking_loss,
     get_loss_list,
-    resolve_loss_weights,
-    get_training_modes,
 )
+
 from nextrec.utils.types import (
     LossName,
     OptimizerName,
@@ -100,6 +99,7 @@ from nextrec.utils.types import (
     MetricsName,
 )

+from nextrec.utils.data import FILE_FORMAT_CONFIG

 class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
@@ -110,6 +110,30 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     def default_task(self) -> TaskTypeName | list[TaskTypeName]:
         raise NotImplementedError

+    @property
+    def training_mode(self) -> TrainingModeName | list[TrainingModeName]:
+        if self.nums_task > 1:
+            return self.training_modes
+        return self.training_modes[0] if self.training_modes else "pointwise"
+
+    @training_mode.setter
+    def training_mode(self, training_mode: TrainingModeName | list[TrainingModeName]):
+        valid_modes = {"pointwise", "pairwise", "listwise"}
+        if isinstance(training_mode, list):
+            training_modes = list(training_mode)
+            if len(training_modes) != self.nums_task:
+                raise ValueError(
+                    "[BaseModel-init Error] training_mode list length must match number of tasks."
+                )
+        else:
+            training_modes = [training_mode] * self.nums_task
+        if any(mode not in valid_modes for mode in training_modes):
+            raise ValueError(
+                "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+            )
+        self.training_modes = list(training_modes)
+
     def __init__(
         self,
         dense_features: list[DenseFeature] | None = None,
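The new training_mode property/setter pair validates and broadcasts the requested mode(s) into self.training_modes; the getter returns a single mode for single-task models and the full list for multi-task ones. A minimal usage sketch, assuming `model` is any BaseModel subclass with nums_task == 2 (illustrative, not taken from the package's tests):

    model.training_mode = "pairwise"                 # broadcast to ["pairwise", "pairwise"]
    assert model.training_modes == ["pairwise", "pairwise"]
    model.training_mode = ["pointwise", "listwise"]  # explicit per-task modes
    model.training_mode = ["pointwise"]              # ValueError: length must match nums_task
    model.training_mode = "ranknet"                  # ValueError: not a valid mode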
@@ -193,10 +217,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         self.task = task or self.default_task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-
-        self.training_mode = (
-            self.training_modes if self.nums_task > 1 else self.training_modes[0]
-        )
+
+        self.training_mode = training_mode

         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
@@ -215,6 +237,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

         self.train_data_summary = None
         self.valid_data_summary = None
+        self.note = None

     def register_regularization_weights(
         self,
@@ -222,6 +245,15 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         exclude_modules: list[str] | None = None,
         include_modules: list[str] | None = None,
     ):
+        """
+        Register parameters for regularization.
+        By default, all nn.Linear weights (excluding those in BatchNorm/Dropout layers) and embedding weights under `embedding_attr` are registered.
+
+        Args:
+            embedding_attr: Attribute name of the embedding layer/module.
+            exclude_modules: List of module name substrings to exclude from regularization.
+            include_modules: List of module name substrings to include for regularization. If provided, only modules containing these substrings are included.
+        """
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
@@ -268,6 +300,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             existing_reg_ids.add(id(module.weight))

     def add_reg_loss(self) -> torch.Tensor:
+        """
+        Compute the regularization loss based on registered parameters and their respective regularization strengths.
+        """
         reg_loss = torch.tensor(0.0, device=self.device)

         if self.embedding_l1_reg > 0:
@@ -289,9 +324,25 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         )
         return reg_loss

+    # todo: support building pairwise/listwise labels in input
     def get_input(self, input_data: dict, require_labels: bool = True):
+        """
+        Prepare unified input features and labels from the given input data.
+
+
+        Args:
+            input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
+            require_labels: Whether labels are required in the input data. Default is True (for training and evaluation with labels).
+
+        Note:
+            The target tensor shape is always (batch_size, num_targets).
+        """
         feature_source = input_data.get("features", {})
+        # todo: pairwise/listwise label support
+        # "labels": {...} should contain pointwise labels / pair indices / list indices / relevance scores
+        # currently only pointwise labels are supported
         label_source = input_data.get("labels")
+
         X_input = {}
         for feature in self.all_features:
             if feature.name not in feature_source:
@@ -307,13 +358,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             device=self.device,
         )
         y = None
+        # labels are needed: training or evaluation with labels
         if len(self.target_columns) > 0 and (
             require_labels
             or (
                 label_source
                 and any(name in label_source for name in self.target_columns)
             )
-        ):
+        ):
             target_tensors = []
             for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
@@ -358,6 +410,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         This function will split training data into training and validation sets when:
         1. valid_data is None;
         2. valid_split is provided.
+
+        Returns:
+            train_loader: DataLoader for training data.
+            valid_split_data: Validation data dict/DataFrame split from the training data.
         """
         if not (0 < valid_split < 1):
             raise ValueError(
@@ -375,7 +431,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             )
         else:
             raise TypeError(
-                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be
+                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be a DataFrame or a dict, now got {type(train_data)}"
             )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
@@ -426,7 +482,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         Args:
             optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
             optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
-            scheduler: Learning rate scheduler name or instance. e.g., '
+            scheduler: Learning rate scheduler name or instance. e.g., 'step', 'cosine', or torch.optim.lr_scheduler.StepLR().
             scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(); custom loss functions are also supported.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
@@ -435,36 +491,31 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
         """
         self.ignore_label = ignore_label
-        default_losses = {
-            "pointwise": "bce",
-            "pairwise": "bpr",
-            "listwise": "listnet",
-        }
         loss_list = get_loss_list(
-            loss, self.training_modes, self.nums_task
+            loss, self.training_modes, self.nums_task
         )
-
-
+
+        self.loss_params = {} if loss_params is None else loss_params
+        self.optimizer_params = optimizer_params or {}
+        self.scheduler_params = scheduler_params or {}
+
         self.optimizer_name = (
             optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         )
-        self.optimizer_params = optimizer_params
         self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
-            **optimizer_params,
+            **self.optimizer_params,
         )

-        scheduler_params = scheduler_params or {}
         if scheduler is None:
             self.scheduler_name = None
         elif isinstance(scheduler, str):
             self.scheduler_name = scheduler
         else:
             self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
-        self.scheduler_params = scheduler_params
         self.scheduler_fn = (
-            get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
+            get_scheduler(scheduler, self.optimizer_fn, **self.scheduler_params)
             if scheduler
             else None
         )
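For concreteness, the get_input contract documented in the hunk at @@ -289,9 +324,25 @@ expects a nested dict. A minimal conforming batch might look as follows (feature and label names are placeholders, not from this diff):

    # Hypothetical model with features "user_id"/"item_id" and target column "label".
    batch = {
        "features": {"user_id": [1, 2, 3], "item_id": [10, 20, 30]},
        "labels": {"label": [1, 0, 1]},
    }
    X_input, y = model.get_input(batch, require_labels=True)
    # y is a tensor of shape (batch_size, num_targets) == (3, 1) on model.device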
@@ -482,35 +533,54 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             for i in range(self.nums_task)
         ]

+        # loss weighting (grad norm or fixed weights)
         self.grad_norm = None
         self.grad_norm_shared_params = None
-        if loss_weights == "grad_norm":
-            if self.nums_task == 1:
-                raise ValueError(
-                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
-                )
-            self.grad_norm = GradNormLossWeighting(
-                nums_task=self.nums_task, device=self.device
-            )
-            self.loss_weights = None
-        elif (
-            isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
-        ):
+        is_grad_norm = (
+            loss_weights == "grad_norm"
+            or isinstance(loss_weights, dict)
+            and loss_weights.get("method") == "grad_norm"
+        )
+        if is_grad_norm:
             if self.nums_task == 1:
                 raise ValueError(
                     "[BaseModel-compile Error] GradNorm requires multi-task setup."
                 )
-            grad_norm_params = dict(loss_weights)
+            grad_norm_params = dict(loss_weights) if isinstance(loss_weights, dict) else {}
             grad_norm_params.pop("method", None)
             self.grad_norm = GradNormLossWeighting(
                 nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
             self.loss_weights = None
+        elif loss_weights is None:
+            self.loss_weights = None
+        elif self.nums_task == 1:
+            if isinstance(loss_weights, (list, tuple)):
+                if len(loss_weights) != 1:
+                    raise ValueError(
+                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
+                    )
+                loss_weights = loss_weights[0]
+            self.loss_weights = [float(loss_weights)]
+        elif isinstance(loss_weights, (int, float)):
+            self.loss_weights = [float(loss_weights)] * self.nums_task
+        elif isinstance(loss_weights, (list, tuple)):
+            weights = [float(w) for w in loss_weights]
+            if len(weights) != self.nums_task:
+                raise ValueError(
+                    f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
+                )
+            self.loss_weights = weights
         else:
-
+            raise TypeError(
+                f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
+            )
         self.compiled = True

     def compute_loss(self, y_pred, y_true):
+        """
+        Compute the loss between predictions and ground truth labels, with loss weighting and ignore_label handling.
+        """
         if y_true is None:
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
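The reworked branch above accepts four shapes of loss_weights in compile(). A hedged sketch of the accepted forms for a hypothetical two-task model (other compile arguments omitted):

    model.compile(loss=["bce", "mse"], loss_weights="grad_norm")              # adaptive GradNorm
    model.compile(loss=["bce", "mse"], loss_weights={"method": "grad_norm"})  # dict form; remaining keys go to GradNormLossWeighting
    model.compile(loss=["bce", "mse"], loss_weights=0.5)                      # scalar, broadcast to [0.5, 0.5]
    model.compile(loss=["bce", "mse"], loss_weights=[0.7, 0.3])               # per-task list
    model.compile(loss=["bce", "mse"], loss_weights=[0.7])                    # ValueError: length != nums_task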
@@ -522,13 +592,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             y_pred = y_pred.view(-1, 1)
         if y_true.dim() == 1:
             y_true = y_true.view(-1, 1)
-        if y_pred.shape != y_true.shape:
-            raise ValueError(
-                f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
-            )

         loss_fn = self.loss_fn[0]
-
+
+        # mask ignored labels
+        # using ignore_label for single-task training is not recommended
         if self.ignore_label is not None:
             valid_mask = y_true != self.ignore_label
             if valid_mask.dim() > 1:
@@ -559,9 +627,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             loss *= self.loss_weights[0]
             return loss

-        # multi-task
-        if y_pred.shape != y_true.shape:
-            raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+        # multi-task: slice predictions and labels per task
         slices = (
             self.prediction_layer.task_slices  # type: ignore
             if hasattr(self, "prediction_layer")
@@ -593,9 +659,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 )
             else:
                 task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-            task_loss = normalize_task_loss(
-                task_loss, valid_count, total_count
-            )  # normalize by valid samples to avoid loss scale issues
+            # task_loss = normalize_task_loss(
+            #     task_loss, valid_count, total_count
+            # )  # normalize by valid samples to avoid loss scale issues
             task_losses.append(task_loss)

         if self.grad_norm is not None:
@@ -624,6 +690,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
     ):
         """
         Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
+
+        Args:
+            data: Input data (dict/DataFrame/DataLoader).
+            batch_size: Batch size.
+            shuffle: Whether to shuffle the data (ignored when a sampler is provided).
+            num_workers: Number of DataLoader workers.
+            sampler: Optional sampler for the DataLoader.
+            return_dataset: Whether to return the tensor dataset along with the DataLoader; used for validation data.
+        Returns:
+            DataLoader (and the tensor dataset if return_dataset is True).
         """
         if isinstance(data, DataLoader):
             return (data, None) if return_dataset else data
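The masking comments above describe how ignore_label drops unknown labels before the loss call. An illustrative standalone snippet of the same idea (this mirrors the described behavior, not nextrec's exact code):

    import torch

    y_true = torch.tensor([[1.0], [0.0], [-1.0], [1.0]])  # -1 marks an unknown label
    y_pred = torch.tensor([[0.9], [0.2], [0.5], [0.7]])
    valid_mask = y_true != -1.0                            # boolean mask over labels
    loss = torch.nn.functional.binary_cross_entropy(
        y_pred[valid_mask], y_true[valid_mask]
    )  # gradient flows only through rows with known labels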
@@ -676,6 +752,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         swanlab_kwargs: dict | None = None,
         auto_ddp_sampler: bool = True,
         log_interval: int = 1,
+        note: str | None = None,
         summary_sections: (
             list[Literal["feature", "model", "train", "data"]] | None
         ) = None,
@@ -707,6 +784,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             swanlab_kwargs: Optional kwargs for swanlab.init(...).
             auto_ddp_sampler: Attach DistributedSampler automatically when distributed; set False when data is already sharded per rank.
             log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+            note: Optional note for the training run.
             summary_sections: Optional summary sections to print. Choose from
                 ["feature", "model", "train", "data"]. Defaults to all.

@@ -770,11 +848,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         self.metrics_sample_limit = (
             None if metrics_sample_limit is None else int(metrics_sample_limit)
         )
+        self.note = note

         training_config = {}
         if self.is_main_process:
             training_config = {
                 "model_name": getattr(self, "model_name", self.__class__.__name__),
+                "note": self.note,
                 "task": self.task,
                 "target_columns": self.target_columns,
                 "batch_size": batch_size,
@@ -1253,7 +1333,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         for batch_index, batch_data in batch_iter:
             batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
-            # call via __call__ so DDP hooks run
+            # call via __call__ so DDP hooks run
             y_pred = model(X_input)  # type: ignore

             loss = self.compute_loss(y_pred, y_true)
@@ -1556,7 +1636,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
-        Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.

         Args:
@@ -1569,6 +1648,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
             stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
             num_workers: DataLoader worker count.
+
+        Note:
+            predict does not currently support distributed mode; treat it as a single-process operation.
         """
         self.eval()
         # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
@@ -1753,6 +1835,21 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         return_dataframe: bool,
         id_columns: list[str] | None = None,
     ):
+        """
+        Make predictions on the given data using streaming mode for large datasets.
+
+        Args:
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction.
+            save_path: Path to save predictions.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output.
+            stream_chunk_size: Number of rows per chunk when using streaming mode.
+            return_dataframe: Whether to return predictions as a pandas DataFrame.
+            id_columns: Column name(s) to use as IDs; if None, uses the model's id_columns.
+        Note:
+            This method uses streaming writes to handle large datasets without loading all data into memory.
+        """
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
@@ -1795,8 +1892,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
         )

-        from nextrec.utils.data import FILE_FORMAT_CONFIG
-
         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

         target_path = get_save_path(
@@ -1908,6 +2003,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         add_timestamp: bool | None = None,
         verbose: bool = True,
     ):
+        """
+        Save the model state and features configuration to disk.
+
+        Args:
+            save_path: Path to save the model; if None, saves to the session's model directory.
+            add_timestamp: Whether to add a timestamp to the filename; if None, defaults to False.
+            verbose: Whether to log the save location.
+        """
         add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = get_save_path(
             path=save_path,
@@ -1950,6 +2053,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         map_location: str | torch.device | None = "cpu",
         verbose: bool = True,
     ):
+        """
+        Load the model state and features configuration from disk.
+
+        Args:
+            save_path: Path to load the model from; can be a directory or a specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            verbose: Whether to log the load location.
+        """
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
@@ -2016,6 +2127,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         """
         Load a model from a checkpoint path. The checkpoint path should contain:
         a .pt file and a features_config.pkl file.
+
+        Args:
+            checkpoint_path: Path to the checkpoint directory or specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            device: Device to place the model on after loading.
+            session_id: Optional session ID for the model.
+            **kwargs: Additional keyword arguments to pass to the model constructor.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
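The new note parameter threads a free-form string from fit() into the stored training config and the printed summary. A hedged usage sketch (data and model setup omitted; other fit arguments as documented above):

    model.fit(
        train_data=train_df,
        epochs=3,
        batch_size=512,
        note="baseline: bce loss, no scheduler",  # echoed as "Note: ..." in the summary
    )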
@@ -2135,6 +2253,7 @@ class BaseMatchModel(BaseModel):
             target=target,
             id_columns=id_columns,
             task=task,
+            training_mode=training_mode,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
@@ -2157,10 +2276,13 @@ class BaseMatchModel(BaseModel):
         self.item_sparse_features = item_sparse_features
         self.item_sequence_features = item_sequence_features

-        self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
+        if self.training_mode not in self.support_training_modes:
+            raise ValueError(
+                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+            )
         self.user_features_all = (
             self.user_dense_features
             + self.user_sparse_features
@@ -2209,11 +2331,6 @@ class BaseMatchModel(BaseModel):
             loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
             loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
         """
-        if self.training_mode not in self.support_training_modes:
-            raise ValueError(
-                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
-            )
-
         default_loss_by_mode = {
             "pointwise": "bce",
             "pairwise": "bpr",
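Net effect of the two hunks above: the supported-modes check moves from compile() into __init__, so an unsupported mode now fails at construction time. A sketch, with SomeMatchModel standing in for any BaseMatchModel subclass whose support_training_modes excludes 'listwise':

    model = SomeMatchModel(..., training_mode="listwise")
    # 0.4.24: the ValueError surfaced only later, inside model.compile(...)
    # 0.4.25: __init__ raises immediately with the supported-modes message above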
nextrec/basic/summary.py
CHANGED
@@ -48,6 +48,27 @@ class SummarySet:
     checkpoint_path: str
     train_data_summary: dict[str, Any] | None
     valid_data_summary: dict[str, Any] | None
+    note: str | None
+
+    def collect_dataloader_summary(self, data_loader: DataLoader | None):
+        if data_loader is None:
+            return None
+
+        summary = {
+            "batch_size": data_loader.batch_size,
+            "num_workers": data_loader.num_workers,
+            "pin_memory": data_loader.pin_memory,
+            "persistent_workers": data_loader.persistent_workers,
+        }
+        prefetch_factor = getattr(data_loader, "prefetch_factor", None)
+        if prefetch_factor is not None:
+            summary["prefetch_factor"] = prefetch_factor
+
+        sampler = getattr(data_loader, "sampler", None)
+        if sampler is not None:
+            summary["sampler"] = sampler.__class__.__name__
+
+        return summary or None

     def build_data_summary(
         self, data: Any, data_loader: DataLoader | None, sample_key: str
@@ -66,6 +87,10 @@ class SummarySet:
         if train_size is not None:
             summary[sample_key] = int(train_size)

+        dataloader_summary = self.collect_dataloader_summary(data_loader)
+        if dataloader_summary:
+            summary["dataloader"] = dataloader_summary
+
         if labels:
             task_types = list(self.task) if isinstance(self.task, list) else [self.task]
             if len(task_types) != len(self.target_columns):
@@ -321,6 +346,7 @@ class SummarySet:
         logger.info(f"  Session ID: {self.session_id}")
         logger.info(f"  Features Config Path: {self.features_config_path}")
         logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")
+        logger.info(f"  Note: {self.note}")

         if "Data Summary" in selected_sections and (
             self.train_data_summary or self.valid_data_summary
@@ -341,6 +367,22 @@ class SummarySet:
             for label, value in lines:
                 logger.info(f"  {format_kv(label, value)}")

+            dataloader_info = self.train_data_summary.get("dataloader")
+            if isinstance(dataloader_info, dict):
+                logger.info("Train DataLoader:")
+                for key in (
+                    "batch_size",
+                    "num_workers",
+                    "pin_memory",
+                    "persistent_workers",
+                    "sampler",
+                ):
+                    if key in dataloader_info:
+                        label = key.replace("_", " ").title()
+                        logger.info(
+                            format_kv(label, dataloader_info[key], indent=2)
+                        )
+
         if self.valid_data_summary:
             if self.train_data_summary:
                 logger.info("")
@@ -355,3 +397,19 @@ class SummarySet:
             logger.info(f"{target_name}:")
             for label, value in lines:
                 logger.info(f"  {format_kv(label, value)}")
+
+            dataloader_info = self.valid_data_summary.get("dataloader")
+            if isinstance(dataloader_info, dict):
+                logger.info("Valid DataLoader:")
+                for key in (
+                    "batch_size",
+                    "num_workers",
+                    "pin_memory",
+                    "persistent_workers",
+                    "sampler",
+                ):
+                    if key in dataloader_info:
+                        label = key.replace("_", " ").title()
+                        logger.info(
+                            format_kv(label, dataloader_info[key], indent=2)
+                        )
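collect_dataloader_summary only reads standard torch DataLoader attributes, so it works on any loader. A hedged sketch of a typical result (whether prefetch_factor appears and the exact sampler name depend on the torch version and loader settings):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loader = DataLoader(TensorDataset(torch.zeros(100, 4)),
                        batch_size=32, num_workers=2, pin_memory=True)
    # summary_set.collect_dataloader_summary(loader) would yield roughly:
    # {"batch_size": 32, "num_workers": 2, "pin_memory": True,
    #  "persistent_workers": False, "prefetch_factor": 2,
    #  "sampler": "SequentialSampler"}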
nextrec/cli.py
CHANGED
@@ -320,6 +320,7 @@ def train_model(train_config_path: str) -> None:
             streaming=True,
             chunk_size=dataloader_chunk_size,
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
         valid_loader = None
         if val_data_path:
@@ -331,6 +332,7 @@ def train_model(train_config_path: str) -> None:
                 streaming=True,
                 chunk_size=dataloader_chunk_size,
                 num_workers=dataloader_cfg.get("num_workers", 0),
+                prefetch_factor=dataloader_cfg.get("prefetch_factor"),
             )
         elif streaming_valid_files:
             valid_loader = dataloader.create_dataloader(
@@ -340,6 +342,7 @@ def train_model(train_config_path: str) -> None:
                 streaming=True,
                 chunk_size=dataloader_chunk_size,
                 num_workers=dataloader_cfg.get("num_workers", 0),
+                prefetch_factor=dataloader_cfg.get("prefetch_factor"),
             )
     else:
         train_loader = dataloader.create_dataloader(
@@ -347,12 +350,14 @@ def train_model(train_config_path: str) -> None:
             batch_size=dataloader_cfg.get("train_batch_size", 512),
             shuffle=dataloader_cfg.get("train_shuffle", True),
             num_workers=dataloader_cfg.get("num_workers", 0),
+            prefetch_factor=dataloader_cfg.get("prefetch_factor"),
         )
         valid_loader = dataloader.create_dataloader(
            data=valid_data,
            batch_size=dataloader_cfg.get("valid_batch_size", 512),
            shuffle=dataloader_cfg.get("valid_shuffle", False),
            num_workers=dataloader_cfg.get("num_workers", 0),
+           prefetch_factor=dataloader_cfg.get("prefetch_factor"),
        )

    model_cfg.setdefault("session_id", session_id)
@@ -383,6 +388,7 @@ def train_model(train_config_path: str) -> None:
        loss=train_cfg.get("loss", "focal"),
        loss_params=train_cfg.get("loss_params", {}),
        loss_weights=train_cfg.get("loss_weights"),
+       ignore_label=train_cfg.get("ignore_label", -1),
    )

    model.fit(
@@ -397,6 +403,12 @@ def train_model(train_config_path: str) -> None:
        num_workers=dataloader_cfg.get("num_workers", 0),
        user_id_column=id_column,
        use_tensorboard=False,
+       use_wandb=train_cfg.get("use_wandb", False),
+       use_swanlab=train_cfg.get("use_swanlab", False),
+       wandb_api=train_cfg.get("wandb_api"),
+       swanlab_api=train_cfg.get("swanlab_api"),
+       log_interval=train_cfg.get("log_interval", 1),
+       note=train_cfg.get("note"),
    )

@@ -583,6 +595,7 @@ def predict_model(predict_config_path: str) -> None:
        shuffle=False,
        streaming=predict_cfg.get("streaming", True),
        chunk_size=predict_cfg.get("chunk_size", 20000),
+       prefetch_factor=predict_cfg.get("prefetch_factor"),
    )

    save_format = predict_cfg.get(
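Taken together, the cli.py hunks honor several new config keys. A hedged sketch of the corresponding sections as Python dicts (key names come straight from the .get(...) calls above; the nesting and on-disk format of the actual config file are not shown in this diff):

    dataloader_cfg = {
        "num_workers": 4,
        "prefetch_factor": 2,        # new: forwarded to create_dataloader
    }
    train_cfg = {
        "loss": "focal",
        "ignore_label": -1,          # new: forwarded to model.compile
        "use_wandb": False,          # new: tracking switches forwarded to model.fit
        "use_swanlab": False,
        "wandb_api": None,
        "swanlab_api": None,
        "log_interval": 1,
        "note": "free-form run note, surfaced in the training summary",
    }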