nextrec-0.4.24-py3-none-any.whl → nextrec-0.4.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/asserts.py +72 -0
  3. nextrec/basic/loggers.py +18 -1
  4. nextrec/basic/model.py +191 -71
  5. nextrec/basic/summary.py +58 -0
  6. nextrec/cli.py +13 -0
  7. nextrec/data/data_processing.py +3 -9
  8. nextrec/data/dataloader.py +25 -2
  9. nextrec/data/preprocessor.py +283 -36
  10. nextrec/models/multi_task/[pre]aitm.py +173 -0
  11. nextrec/models/multi_task/[pre]snr_trans.py +232 -0
  12. nextrec/models/multi_task/[pre]star.py +192 -0
  13. nextrec/models/multi_task/apg.py +330 -0
  14. nextrec/models/multi_task/cross_stitch.py +229 -0
  15. nextrec/models/multi_task/escm.py +290 -0
  16. nextrec/models/multi_task/esmm.py +8 -21
  17. nextrec/models/multi_task/hmoe.py +203 -0
  18. nextrec/models/multi_task/mmoe.py +20 -28
  19. nextrec/models/multi_task/pepnet.py +68 -66
  20. nextrec/models/multi_task/ple.py +30 -44
  21. nextrec/models/multi_task/poso.py +13 -22
  22. nextrec/models/multi_task/share_bottom.py +14 -25
  23. nextrec/models/ranking/afm.py +2 -2
  24. nextrec/models/ranking/autoint.py +2 -4
  25. nextrec/models/ranking/dcn.py +2 -3
  26. nextrec/models/ranking/dcn_v2.py +2 -3
  27. nextrec/models/ranking/deepfm.py +2 -3
  28. nextrec/models/ranking/dien.py +7 -9
  29. nextrec/models/ranking/din.py +8 -10
  30. nextrec/models/ranking/eulernet.py +1 -2
  31. nextrec/models/ranking/ffm.py +1 -2
  32. nextrec/models/ranking/fibinet.py +2 -3
  33. nextrec/models/ranking/fm.py +1 -1
  34. nextrec/models/ranking/lr.py +1 -1
  35. nextrec/models/ranking/masknet.py +1 -2
  36. nextrec/models/ranking/pnn.py +1 -2
  37. nextrec/models/ranking/widedeep.py +2 -3
  38. nextrec/models/ranking/xdeepfm.py +2 -4
  39. nextrec/models/representation/rqvae.py +4 -4
  40. nextrec/models/retrieval/dssm.py +18 -26
  41. nextrec/models/retrieval/dssm_v2.py +15 -22
  42. nextrec/models/retrieval/mind.py +9 -15
  43. nextrec/models/retrieval/sdm.py +36 -33
  44. nextrec/models/retrieval/youtube_dnn.py +16 -24
  45. nextrec/models/sequential/hstu.py +2 -2
  46. nextrec/utils/__init__.py +5 -1
  47. nextrec/utils/config.py +2 -0
  48. nextrec/utils/model.py +16 -77
  49. nextrec/utils/torch_utils.py +11 -0
  50. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/METADATA +72 -62
  51. nextrec-0.4.27.dist-info/RECORD +90 -0
  52. nextrec/models/multi_task/aitm.py +0 -0
  53. nextrec/models/multi_task/snr_trans.py +0 -0
  54. nextrec-0.4.24.dist-info/RECORD +0 -86
  55. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/WHEEL +0 -0
  56. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/entry_points.txt +0 -0
  57. {nextrec-0.4.24.dist-info → nextrec-0.4.27.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on 30/12/2025
+Checkpoint: edit on 01/01/2026
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -36,6 +36,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from nextrec import __version__
+from nextrec.basic.asserts import assert_task
 from nextrec.basic.callback import (
     CallbackList,
     CheckpointSaver,
@@ -88,9 +89,8 @@ from nextrec.utils.config import safe_value
 from nextrec.utils.model import (
     compute_ranking_loss,
     get_loss_list,
-    resolve_loss_weights,
-    get_training_modes,
 )
+
 from nextrec.utils.types import (
     LossName,
     OptimizerName,
@@ -100,6 +100,8 @@ from nextrec.utils.types import (
     MetricsName,
 )
 
+from nextrec.utils.data import FILE_FORMAT_CONFIG
+
 
 class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
@@ -118,7 +120,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
         task: TaskTypeName | list[TaskTypeName] | None = None,
-        training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
+        training_mode: TrainingModeName | list[TrainingModeName] | None = None,
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
@@ -138,10 +140,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             dense_features: DenseFeature definitions.
             sparse_features: SparseFeature definitions.
             sequence_features: SequenceFeature definitions.
-            target: Target column name. e.g., 'label' or ['label1', 'label2'].
+            target: Target column name. e.g., 'label_ctr' or ['label_ctr', 'label_cvr'].
             id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
             task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
-            training_mode: Training mode for ranking tasks; a single mode or a list per task.
+            training_mode: Training mode for different tasks. e.g., 'pointwise', ['pointwise', 'pairwise'].
 
             embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
             dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
@@ -193,10 +195,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
 
         self.task = task or self.default_task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-        self.training_modes = get_training_modes(training_mode, self.nums_task)
-        self.training_mode = (
-            self.training_modes if self.nums_task > 1 else self.training_modes[0]
-        )
+
+        training_mode = training_mode or "pointwise"
+        if isinstance(training_mode, list):
+            self.training_modes = list(training_mode)
+        else:
+            self.training_modes = [training_mode] * self.nums_task
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
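Note on the training_mode change above: the default moves from the literal "pointwise" to None (normalized back to "pointwise" internally), and a single string is now broadcast to one mode per task. A minimal runnable sketch of the new normalization, assuming a hypothetical two-task setup (only the normalization logic is taken from the diff):

    # Hypothetical two-task setup; the normalization mirrors the hunk above.
    nums_task = 2                                 # e.g., task=["binary", "binary"]
    training_mode = None                          # new default
    training_mode = training_mode or "pointwise"
    if isinstance(training_mode, list):
        training_modes = list(training_mode)      # per-task modes, e.g. ["pointwise", "pairwise"]
    else:
        training_modes = [training_mode] * nums_task  # broadcast a single mode
    assert training_modes == ["pointwise", "pointwise"]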
@@ -215,6 +219,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
 
         self.train_data_summary = None
         self.valid_data_summary = None
+        self.note = None
 
     def register_regularization_weights(
         self,
@@ -222,6 +227,15 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         exclude_modules: list[str] | None = None,
         include_modules: list[str] | None = None,
     ):
+        """
+        Register parameters for regularization.
+        By default, all nn.Linear weights (excluding those in BatchNorm/Dropout layers) and embedding weights under `embedding_attr` are registered.
+
+        Args:
+            embedding_attr: Attribute name of the embedding layer/module.
+            exclude_modules: List of module name substrings to exclude from regularization.
+            include_modules: List of module name substrings to include for regularization. If provided, only modules containing these substrings are included.
+        """
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
@@ -268,6 +282,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 existing_reg_ids.add(id(module.weight))
 
     def add_reg_loss(self) -> torch.Tensor:
+        """
+        Compute the regularization loss based on registered parameters and their respective regularization strengths.
+        """
         reg_loss = torch.tensor(0.0, device=self.device)
 
         if self.embedding_l1_reg > 0:
@@ -289,9 +306,25 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             )
         return reg_loss
 
+    # todo: support build pairwise/listwise label in input
     def get_input(self, input_data: dict, require_labels: bool = True):
+        """
+        Prepare unified input features and labels from the given input data.
+
+
+        Args:
+            input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
+            require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
+
+        Note:
+            target tensor shape will always be (batch_size, num_targets)
+        """
         feature_source = input_data.get("features", {})
+        # todo: pairwise/listwise label support
+        # "labels": {...} should contain pointwise/pair index/list index/ relevance scores
+        # now only have pointwise label support
         label_source = input_data.get("labels")
+
         X_input = {}
         for feature in self.all_features:
             if feature.name not in feature_source:
@@ -307,13 +340,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 device=self.device,
             )
         y = None
+        # if need labels: training or eval with labels
         if len(self.target_columns) > 0 and (
             require_labels
             or (
                 label_source
                 and any(name in label_source for name in self.target_columns)
             )
-        ):  # need labels: training or eval with labels
+        ):
             target_tensors = []
             for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
@@ -358,6 +392,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         This function will split training data into training and validation sets when:
         1. valid_data is None;
         2. valid_split is provided.
+
+        Returns:
+            train_loader: DataLoader for training data.
+            valid_split_data: Validation data dict/dataframe split from training data.
         """
         if not (0 < valid_split < 1):
             raise ValueError(
@@ -375,7 +413,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             )
         else:
             raise TypeError(
-                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
+                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be DataFrame or a dict, now got {type(train_data)}"
             )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
@@ -426,7 +464,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         Args:
             optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
             optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
-            scheduler: Learning rate scheduler name or instance. e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
+            scheduler: Learning rate scheduler name or instance. e.g., 'step', 'cosine', or torch.optim.lr_scheduler.StepLR().
             scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
@@ -435,36 +473,31 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
         """
         self.ignore_label = ignore_label
-        default_losses = {
-            "pointwise": "bce",
-            "pairwise": "bpr",
-            "listwise": "listnet",
-        }
-        loss_list = get_loss_list(
-            loss, self.training_modes, self.nums_task, default_losses
-        )
-        self.loss_params = loss_params or {}
-        optimizer_params = optimizer_params or {}
+
+        # get loss list
+        loss_list = get_loss_list(loss, self.training_modes, self.nums_task)
+
+        self.loss_params = {} if loss_params is None else loss_params
+        self.optimizer_params = optimizer_params or {}
+        self.scheduler_params = scheduler_params or {}
+
         self.optimizer_name = (
             optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         )
-        self.optimizer_params = optimizer_params
         self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
-            **optimizer_params,
+            **self.optimizer_params,
         )
 
-        scheduler_params = scheduler_params or {}
         if scheduler is None:
             self.scheduler_name = None
         elif isinstance(scheduler, str):
             self.scheduler_name = scheduler
         else:
             self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
-        self.scheduler_params = scheduler_params
         self.scheduler_fn = (
-            get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
+            get_scheduler(scheduler, self.optimizer_fn, **self.scheduler_params)
             if scheduler
             else None
         )
@@ -482,35 +515,56 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             for i in range(self.nums_task)
         ]
 
+        # loss weighting (grad norm or fixed weights)
         self.grad_norm = None
         self.grad_norm_shared_params = None
-        if isinstance(loss_weights, str) and loss_weights.lower() == "grad_norm":
+        is_grad_norm = (
+            loss_weights == "grad_norm"
+            or isinstance(loss_weights, dict)
+            and loss_weights.get("method") == "grad_norm"
+        )
+        if is_grad_norm:
             if self.nums_task == 1:
                 raise ValueError(
                     "[BaseModel-compile Error] GradNorm requires multi-task setup."
                 )
-            self.grad_norm = GradNormLossWeighting(
-                nums_task=self.nums_task, device=self.device
+            grad_norm_params = (
+                dict(loss_weights) if isinstance(loss_weights, dict) else {}
             )
-            self.loss_weights = None
-        elif (
-            isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
-        ):
-            if self.nums_task == 1:
-                raise ValueError(
-                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
-                )
-            grad_norm_params = dict(loss_weights)
             grad_norm_params.pop("method", None)
             self.grad_norm = GradNormLossWeighting(
                 nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
             self.loss_weights = None
+        elif loss_weights is None:
+            self.loss_weights = None
+        elif self.nums_task == 1:
+            if isinstance(loss_weights, (list, tuple)):
+                if len(loss_weights) != 1:
+                    raise ValueError(
+                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
+                    )
+                loss_weights = loss_weights[0]
+            self.loss_weights = [float(loss_weights)]
+        elif isinstance(loss_weights, (int, float)):
+            self.loss_weights = [float(loss_weights)] * self.nums_task
+        elif isinstance(loss_weights, (list, tuple)):
+            weights = [float(w) for w in loss_weights]
+            if len(weights) != self.nums_task:
+                raise ValueError(
+                    f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
+                )
+            self.loss_weights = weights
         else:
-            self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
+            raise TypeError(
+                f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
+            )
         self.compiled = True
 
     def compute_loss(self, y_pred, y_true):
+        """
+        Compute the loss between predictions and ground truth labels, with loss weighting and ignore_label handling
+        """
         if y_true is None:
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
@@ -522,13 +576,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             y_pred = y_pred.view(-1, 1)
         if y_true.dim() == 1:
             y_true = y_true.view(-1, 1)
-        if y_pred.shape != y_true.shape:
-            raise ValueError(
-                f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
-            )
 
         loss_fn = self.loss_fn[0]
 
+        # mask ignored labels
+        # we don't suggest using ignore_label for single task training
         if self.ignore_label is not None:
             valid_mask = y_true != self.ignore_label
             if valid_mask.dim() > 1:
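With the shape-mismatch guard removed, the ignore_label mask is the main safeguard left in the single-task path. A runnable sketch of what that mask does, with made-up tensors:

    import torch

    # Rows whose label equals ignore_label are excluded from the loss (from the hunk above);
    # the example values are illustrative.
    ignore_label = -1
    y_true = torch.tensor([[1.0], [-1.0], [0.0]])  # -1 marks an unknown label
    valid_mask = y_true != ignore_label            # only rows with real labels keep gradients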
@@ -559,9 +611,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             loss *= self.loss_weights[0]
             return loss
 
-        # multi-task
-        if y_pred.shape != y_true.shape:
-            raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+        # multi-task: slice predictions and labels per task
         slices = (
             self.prediction_layer.task_slices  # type: ignore
             if hasattr(self, "prediction_layer")
@@ -593,9 +643,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 )
             else:
                 task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-                # task_loss = normalize_task_loss(
-                #     task_loss, valid_count, total_count
-                # ) # normalize by valid samples to avoid loss scale issues
+            # task_loss = normalize_task_loss(
+            #     task_loss, valid_count, total_count
+            # ) # normalize by valid samples to avoid loss scale issues
             task_losses.append(task_loss)
 
         if self.grad_norm is not None:
@@ -619,11 +669,23 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         batch_size: int = 32,
         shuffle: bool = True,
         num_workers: int = 0,
+        prefetch_factor: int | None = None,
         sampler=None,
         return_dataset: bool = False,
     ):
         """
         Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
+
+        Args:
+            data: Input data (dict/df/DataLoader).
+            batch_size: Batch size.
+            shuffle: Whether to shuffle the data (ignored when a sampler is provided).
+            num_workers: Number of DataLoader workers.
+            prefetch_factor: Number of batches loaded in advance by each worker.
+            sampler: Optional sampler for DataLoader.
+            return_dataset: Whether to return the tensor dataset along with the DataLoader, used for valid data
+        Returns:
+            DataLoader (and tensor dataset if return_dataset is True).
         """
         if isinstance(data, DataLoader):
             return (data, None) if return_dataset else data
@@ -639,6 +701,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
             )
         dataset = TensorDictDataset(tensors)
+        loader_kwargs = {}
+        if num_workers > 0 and prefetch_factor is not None:
+            loader_kwargs["prefetch_factor"] = prefetch_factor
        loader = DataLoader(
             dataset,
             batch_size=batch_size,
@@ -648,6 +713,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             num_workers=num_workers,
             pin_memory=self.device.type == "cuda",
             persistent_workers=num_workers > 0,
+            **loader_kwargs,
         )
         return (loader, dataset) if return_dataset else loader
 
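The loader_kwargs guard above mirrors PyTorch's DataLoader rule that prefetch_factor is only legal with worker processes; passing it with num_workers=0 raises a ValueError. A runnable sketch with illustrative values:

    # prefetch_factor is forwarded only when workers exist (logic from the hunk above).
    num_workers, prefetch_factor = 4, 2   # prefetch 2 batches per worker
    loader_kwargs = {}
    if num_workers > 0 and prefetch_factor is not None:
        loader_kwargs["prefetch_factor"] = prefetch_factor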
676
742
  swanlab_kwargs: dict | None = None,
677
743
  auto_ddp_sampler: bool = True,
678
744
  log_interval: int = 1,
745
+ note: str | None = None,
679
746
  summary_sections: (
680
747
  list[Literal["feature", "model", "train", "data"]] | None
681
748
  ) = None,
@@ -707,6 +774,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
707
774
  swanlab_kwargs: Optional kwargs for swanlab.init(...).
708
775
  auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
709
776
  log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
777
+ note: Optional note for the training run.
710
778
  summary_sections: Optional summary sections to print. Choose from
711
779
  ["feature", "model", "train", "data"]. Defaults to all.
712
780
 
@@ -720,6 +788,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
720
788
  )
721
789
  self.to(self.device)
722
790
 
791
+ assert_task(self.task, len(self.target_columns), model_name=self.model_name)
792
+
723
793
  if not self.compiled:
724
794
  self.compile(
725
795
  optimizer="adam",
@@ -770,11 +840,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
770
840
  self.metrics_sample_limit = (
771
841
  None if metrics_sample_limit is None else int(metrics_sample_limit)
772
842
  )
843
+ self.note = note
773
844
 
774
845
  training_config = {}
775
846
  if self.is_main_process:
776
847
  training_config = {
777
848
  "model_name": getattr(self, "model_name", self.__class__.__name__),
849
+ "note": self.note,
778
850
  "task": self.task,
779
851
  "target_columns": self.target_columns,
780
852
  "batch_size": batch_size,
@@ -822,6 +894,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
822
894
  else:
823
895
  swanlab.login(api_key=swanlab_api)
824
896
 
897
+ if use_wandb and self.note:
898
+ wandb_kwargs = dict(wandb_kwargs or {})
899
+ wandb_kwargs.setdefault("notes", self.note)
900
+
901
+ if use_swanlab and self.note:
902
+ swanlab_kwargs = dict(swanlab_kwargs or {})
903
+ swanlab_kwargs.setdefault("description", self.note)
904
+
825
905
  self.training_logger = (
826
906
  TrainingLogger(
827
907
  session=self.session,
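Per the hunks above, the new note is stored on the model, written into training_config, and forwarded to wandb ("notes") and swanlab ("description") unless the caller already set those keys. A hedged sketch (train_df and any fit arguments beyond the diffed parameters are assumptions):

    # Illustrative fit() call; `note` and `use_wandb` appear in the diff above.
    model.fit(
        train_data=train_df,          # hypothetical DataFrame
        note="baseline, lr=1e-3",     # surfaces in training_config, wandb notes, swanlab description
        use_wandb=True,
    )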
@@ -1253,7 +1333,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         for batch_index, batch_data in batch_iter:
             batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
-            # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
+            # call via __call__ so DDP hooks run
             y_pred = model(X_input)  # type: ignore
 
             loss = self.compute_loss(y_pred, y_true)
@@ -1556,7 +1636,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray | Path | None:
         """
-        Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.
 
         Args:
@@ -1569,6 +1648,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
             stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
             num_workers: DataLoader worker count.
+
+        Note:
+            predict does not support distributed mode currently, consider it as a single-process operation.
         """
         self.eval()
         # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
@@ -1753,6 +1835,21 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         return_dataframe: bool,
         id_columns: list[str] | None = None,
     ):
+        """
+        Make predictions on the given data using streaming mode for large datasets.
+
+        Args:
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction.
+            save_path: Path to save predictions.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output.
+            stream_chunk_size: Number of rows per chunk when using streaming mode.
+            return_dataframe: Whether to return predictions as a pandas DataFrame.
+            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+        Note:
+            This method uses streaming writes to handle large datasets without loading all data into memory.
+        """
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
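Based on the docstring added above, a hedged sketch of streaming prediction (the paths, and the assumption that the public predict() passes these arguments through to the streaming path, are illustrative):

    # Illustrative streaming predict() call; parameter names come from the docstring above.
    preds = model.predict(
        data="interactions.parquet",  # file input goes through RecDataLoader streaming
        batch_size=4096,
        save_path="preds/run1",
        save_format="parquet",        # csv/parquet stream; other formats buffer in memory first
        stream_chunk_size=100_000,
    )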
@@ -1795,8 +1892,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                 "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
             )
 
-        from nextrec.utils.data import FILE_FORMAT_CONFIG
-
         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
 
         target_path = get_save_path(
@@ -1908,6 +2003,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         add_timestamp: bool | None = None,
         verbose: bool = True,
     ):
+        """
+        Save the model state and features configuration to disk.
+
+        Args:
+            save_path: Path to save the model; if None, saves to the session's model directory.
+            add_timestamp: Whether to add a timestamp to the filename; if None, defaults to True.
+            verbose: Whether to log the save location.
+        """
         add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = get_save_path(
             path=save_path,
@@ -1950,6 +2053,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         map_location: str | torch.device | None = "cpu",
         verbose: bool = True,
     ):
+        """
+        Load the model state and features configuration from disk.
+
+        Args:
+            save_path: Path to load the model from; can be a directory or a specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            verbose: Whether to log the load location.
+        """
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
@@ -2016,6 +2127,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         """
         Load a model from a checkpoint path. The checkpoint path should contain:
         a .pt file and a features_config.pkl file.
+
+        Args:
+            checkpoint_path: Path to the checkpoint directory or specific .pt file.
+            map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+            device: Device to place the model on after loading.
+            session_id: Optional session ID for the model.
+            **kwargs: Additional keyword arguments to pass to the model constructor.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
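The new docstrings spell out the persistence contract: save() writes the model state plus a features configuration, load() accepts a directory or a specific .pt file, and from_checkpoint() rebuilds a model from a directory containing a .pt file and features_config.pkl. A hedged round-trip sketch (the class name and paths are illustrative):

    # Illustrative save/load round trip; argument names come from the docstrings above.
    model.save(save_path="artifacts/deepfm", add_timestamp=False)
    model.load("artifacts/deepfm")                            # directory or specific .pt file
    restored = MyModel.from_checkpoint("artifacts/deepfm", map_location="cpu")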
@@ -2135,6 +2253,7 @@ class BaseMatchModel(BaseModel):
             target=target,
             id_columns=id_columns,
             task=task,
+            training_mode=training_mode,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
@@ -2157,10 +2276,14 @@ class BaseMatchModel(BaseModel):
         self.item_sparse_features = item_sparse_features
         self.item_sequence_features = item_sequence_features
 
-        self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+        if primary_mode not in self.support_training_modes:
+            raise ValueError(
+                f"{self.model_name.upper()} does not support training_mode='{primary_mode}'. Supported modes: {self.support_training_modes}"
+            )
         self.user_features_all = (
             self.user_dense_features
             + self.user_sparse_features
@@ -2176,7 +2299,7 @@ class BaseMatchModel(BaseModel):
         self.head = RetrievalHead(
             similarity_metric=self.similarity_metric,
             temperature=self.temperature,
-            training_mode=self.training_mode,
+            training_mode=primary_mode,
             apply_sigmoid=True,
         )
 
@@ -2209,11 +2332,6 @@ class BaseMatchModel(BaseModel):
             loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
             loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
         """
-        if self.training_mode not in self.support_training_modes:
-            raise ValueError(
-                f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
-            )
-
         default_loss_by_mode = {
             "pointwise": "bce",
             "pairwise": "bpr",
@@ -2221,26 +2339,27 @@ class BaseMatchModel(BaseModel):
         }
 
         effective_loss = loss
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
         if effective_loss is None:
-            effective_loss = default_loss_by_mode[self.training_mode]
+            effective_loss = default_loss_by_mode[primary_mode]
         elif isinstance(effective_loss, str):
-            if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
+            if primary_mode in {"pairwise", "listwise"} and effective_loss in {
                 "bce",
                 "binary_crossentropy",
             }:
-                effective_loss = default_loss_by_mode[self.training_mode]
+                effective_loss = default_loss_by_mode[primary_mode]
         elif isinstance(effective_loss, list):
             if not effective_loss:
-                effective_loss = [default_loss_by_mode[self.training_mode]]
+                effective_loss = [default_loss_by_mode[primary_mode]]
             else:
                 first = effective_loss[0]
                 if (
-                    self.training_mode in {"pairwise", "listwise"}
+                    primary_mode in {"pairwise", "listwise"}
                     and isinstance(first, str)
                     and first in {"bce", "binary_crossentropy"}
                 ):
                     effective_loss = [
-                        default_loss_by_mode[self.training_mode],
+                        default_loss_by_mode[primary_mode],
                         *effective_loss[1:],
                     ]
         return super().compile(
@@ -2318,11 +2437,12 @@ class BaseMatchModel(BaseModel):
         return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)
 
     def compute_loss(self, y_pred, y_true):
-        if self.training_mode == "pointwise":
+        primary_mode = self.training_modes[0] if self.training_modes else "pointwise"
+        if primary_mode == "pointwise":
             return super().compute_loss(y_pred, y_true)
 
         # pairwise / listwise using inbatch neg
-        elif self.training_mode in ["pairwise", "listwise"]:
+        elif primary_mode in ["pairwise", "listwise"]:
             if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
                 raise ValueError(
                     "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
@@ -2365,7 +2485,7 @@ class BaseMatchModel(BaseModel):
             loss *= float(self.loss_weights[0])
             return loss
         else:
-            raise ValueError(f"Unknown training mode: {self.training_mode}")
+            raise ValueError(f"Unknown training mode: {primary_mode}")
 
     def prepare_feature_data(
         self,
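Net effect of the BaseMatchModel changes: the single self.training_mode attribute is gone; retrieval models now derive a primary mode from self.training_modes[0], validate it against support_training_modes at construction time rather than in compile(), and keep per-mode default losses. A runnable sketch of that default-loss selection (the mapping is from the diff; the surrounding code is illustrative):

    # Default loss follows the primary training mode (mapping from the diff above).
    default_loss_by_mode = {"pointwise": "bce", "pairwise": "bpr", "listwise": "listnet"}
    training_modes = ["pairwise"]
    primary_mode = training_modes[0] if training_modes else "pointwise"
    effective_loss = default_loss_by_mode[primary_mode]
    assert effective_loss == "bpr"   # pairwise retrieval defaults to BPR loss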