nextrec 0.4.16__py3-none-any.whl → 0.4.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/heads.py +99 -0
  3. nextrec/basic/loggers.py +5 -5
  4. nextrec/basic/model.py +217 -88
  5. nextrec/cli.py +1 -1
  6. nextrec/data/dataloader.py +93 -95
  7. nextrec/data/preprocessor.py +108 -46
  8. nextrec/loss/grad_norm.py +13 -13
  9. nextrec/models/multi_task/esmm.py +10 -11
  10. nextrec/models/multi_task/mmoe.py +20 -19
  11. nextrec/models/multi_task/ple.py +35 -34
  12. nextrec/models/multi_task/poso.py +23 -21
  13. nextrec/models/multi_task/share_bottom.py +18 -17
  14. nextrec/models/ranking/afm.py +4 -3
  15. nextrec/models/ranking/autoint.py +4 -3
  16. nextrec/models/ranking/dcn.py +4 -3
  17. nextrec/models/ranking/dcn_v2.py +4 -3
  18. nextrec/models/ranking/deepfm.py +4 -3
  19. nextrec/models/ranking/dien.py +2 -2
  20. nextrec/models/ranking/din.py +2 -2
  21. nextrec/models/ranking/eulernet.py +4 -3
  22. nextrec/models/ranking/ffm.py +4 -3
  23. nextrec/models/ranking/fibinet.py +2 -2
  24. nextrec/models/ranking/fm.py +4 -3
  25. nextrec/models/ranking/lr.py +4 -3
  26. nextrec/models/ranking/masknet.py +4 -5
  27. nextrec/models/ranking/pnn.py +5 -4
  28. nextrec/models/ranking/widedeep.py +8 -8
  29. nextrec/models/ranking/xdeepfm.py +5 -4
  30. nextrec/utils/console.py +20 -6
  31. nextrec/utils/data.py +154 -32
  32. nextrec/utils/model.py +86 -1
  33. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/METADATA +5 -6
  34. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/RECORD +37 -36
  35. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/WHEEL +0 -0
  36. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/entry_points.txt +0 -0
  37. {nextrec-0.4.16.dist-info → nextrec-0.4.18.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 20/12/2025
5
+ Checkpoint: edit on 24/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -38,6 +38,7 @@ from nextrec.basic.features import (
38
38
  SequenceFeature,
39
39
  SparseFeature,
40
40
  )
41
+ from nextrec.basic.heads import RetrievalHead
41
42
  from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
42
43
  from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
43
44
  from nextrec.basic.session import create_session, resolve_save_path
@@ -48,6 +49,7 @@ from nextrec.data.dataloader import (
48
49
  TensorDictDataset,
49
50
  build_tensors_from_data,
50
51
  )
52
+ from nextrec.utils.data import check_streaming_support
51
53
  from nextrec.loss import (
52
54
  BPRLoss,
53
55
  GradNormLossWeighting,
@@ -68,6 +70,7 @@ from nextrec.utils.torch_utils import (
68
70
  init_process_group,
69
71
  to_tensor,
70
72
  )
73
+ from nextrec.utils.model import compute_ranking_loss
71
74
 
72
75
 
73
76
  class BaseModel(FeatureSet, nn.Module):
@@ -87,13 +90,18 @@ class BaseModel(FeatureSet, nn.Module):
87
90
  target: list[str] | str | None = None,
88
91
  id_columns: list[str] | str | None = None,
89
92
  task: str | list[str] | None = None,
93
+ training_mode: (
94
+ Literal["pointwise", "pairwise", "listwise"]
95
+ | list[Literal["pointwise", "pairwise", "listwise"]]
96
+ ) = "pointwise",
90
97
  embedding_l1_reg: float = 0.0,
91
98
  dense_l1_reg: float = 0.0,
92
99
  embedding_l2_reg: float = 0.0,
93
100
  dense_l2_reg: float = 0.0,
94
101
  device: str = "cpu",
95
102
  early_stop_patience: int = 20,
96
- max_metrics_samples: int | None = 200000,
103
+ early_stop_monitor_task: str | None = None,
104
+ metrics_sample_limit: int | None = 200000,
97
105
  session_id: str | None = None,
98
106
  callbacks: list[Callback] | None = None,
99
107
  distributed: bool = False,
@@ -112,6 +120,7 @@ class BaseModel(FeatureSet, nn.Module):
112
120
  target: Target column name. e.g., 'label' or ['label1', 'label2'].
113
121
  id_columns: Identifier column name, only need to specify if GAUC is required. e.g., 'user_id'.
114
122
  task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
123
+ training_mode: Training mode for ranking tasks; a single mode or a list per task.
115
124
 
116
125
  embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
117
126
  dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
@@ -120,7 +129,8 @@ class BaseModel(FeatureSet, nn.Module):
120
129
 
121
130
  device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
122
131
  early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
123
- max_metrics_samples: Max samples to keep for training metrics. None disables limit.
132
+ early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
133
+ metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
124
134
  session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
125
135
  callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
126
136
 
@@ -149,9 +159,11 @@ class BaseModel(FeatureSet, nn.Module):
149
159
  self.session = create_session(session_id)
150
160
  self.session_path = self.session.root # pwd/session_id, path for this session
151
161
  self.checkpoint_path = os.path.join(
152
- self.session_path, self.model_name + "_checkpoint.pt"
162
+ self.session_path, self.model_name.upper() + "_checkpoint.pt"
153
163
  ) # e.g., pwd/session_id/DeepFM_checkpoint.pt
154
- self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
164
+ self.best_path = os.path.join(
165
+ self.session_path, self.model_name.upper() + "_best.pt"
166
+ )
155
167
  self.features_config_path = os.path.join(
156
168
  self.session_path, "features_config.pkl"
157
169
  )
@@ -161,6 +173,22 @@ class BaseModel(FeatureSet, nn.Module):
161
173
 
162
174
  self.task = self.default_task if task is None else task
163
175
  self.nums_task = len(self.task) if isinstance(self.task, list) else 1
176
+ if isinstance(training_mode, list):
177
+ if len(training_mode) != self.nums_task:
178
+ raise ValueError(
179
+ "[BaseModel-init Error] training_mode list length must match number of tasks."
180
+ )
181
+ self.training_modes = list(training_mode)
182
+ else:
183
+ self.training_modes = [training_mode] * self.nums_task
184
+ for mode in self.training_modes:
185
+ if mode not in {"pointwise", "pairwise", "listwise"}:
186
+ raise ValueError(
187
+ "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
188
+ )
189
+ self.training_mode = (
190
+ self.training_modes if self.nums_task > 1 else self.training_modes[0]
191
+ )
164
192
 
165
193
  self.embedding_l1_reg = embedding_l1_reg
166
194
  self.dense_l1_reg = dense_l1_reg
@@ -171,9 +199,10 @@ class BaseModel(FeatureSet, nn.Module):
171
199
  self.loss_weight = None
172
200
 
173
201
  self.early_stop_patience = early_stop_patience
202
+ self.early_stop_monitor_task = early_stop_monitor_task
174
203
  # max samples to keep for training metrics, in case of large training set
175
- self.max_metrics_samples = (
176
- None if max_metrics_samples is None else int(max_metrics_samples)
204
+ self.metrics_sample_limit = (
205
+ None if metrics_sample_limit is None else int(metrics_sample_limit)
177
206
  )
178
207
  self.max_gradient_norm = 1.0
179
208
  self.logger_initialized = False
@@ -397,6 +426,33 @@ class BaseModel(FeatureSet, nn.Module):
397
426
  Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
398
427
  callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
399
428
  """
429
+ default_losses = {
430
+ "pointwise": "bce",
431
+ "pairwise": "bpr",
432
+ "listwise": "listnet",
433
+ }
434
+ effective_loss = loss
435
+ if effective_loss is None:
436
+ loss_list = [default_losses[mode] for mode in self.training_modes]
437
+ elif isinstance(effective_loss, list):
438
+ if not effective_loss:
439
+ loss_list = [default_losses[mode] for mode in self.training_modes]
440
+ else:
441
+ if len(effective_loss) != self.nums_task:
442
+ raise ValueError(
443
+ f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
444
+ )
445
+ loss_list = list(effective_loss)
446
+ else:
447
+ loss_list = [effective_loss] * self.nums_task
448
+
449
+ for idx, mode in enumerate(self.training_modes):
450
+ if isinstance(loss_list[idx], str) and loss_list[idx] in {
451
+ "bce",
452
+ "binary_crossentropy",
453
+ }:
454
+ if mode in {"pairwise", "listwise"}:
455
+ loss_list[idx] = default_losses[mode]
400
456
  if loss_params is None:
401
457
  self.loss_params = {}
402
458
  else:
@@ -426,16 +482,8 @@ class BaseModel(FeatureSet, nn.Module):
426
482
  else None
427
483
  )
428
484
 
429
- self.loss_config = loss
485
+ self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
430
486
  self.loss_params = loss_params or {}
431
- if isinstance(loss, list):
432
- if len(loss) != self.nums_task:
433
- raise ValueError(
434
- f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
435
- )
436
- loss_list = list(loss)
437
- else:
438
- loss_list = [loss] * self.nums_task
439
487
  if isinstance(self.loss_params, dict):
440
488
  loss_params_list = [self.loss_params] * self.nums_task
441
489
  else:
@@ -456,7 +504,7 @@ class BaseModel(FeatureSet, nn.Module):
456
504
  "[BaseModel-compile Error] GradNorm requires multi-task setup."
457
505
  )
458
506
  self.grad_norm = GradNormLossWeighting(
459
- num_tasks=self.nums_task, device=self.device
507
+ nums_task=self.nums_task, device=self.device
460
508
  )
461
509
  self.loss_weights = None
462
510
  elif (
@@ -469,7 +517,7 @@ class BaseModel(FeatureSet, nn.Module):
469
517
  grad_norm_params = dict(loss_weights)
470
518
  grad_norm_params.pop("method", None)
471
519
  self.grad_norm = GradNormLossWeighting(
472
- num_tasks=self.nums_task, device=self.device, **grad_norm_params
520
+ nums_task=self.nums_task, device=self.device, **grad_norm_params
473
521
  )
474
522
  self.loss_weights = None
475
523
  elif loss_weights is None:
@@ -507,6 +555,7 @@ class BaseModel(FeatureSet, nn.Module):
507
555
  raise ValueError(
508
556
  "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
509
557
  )
558
+ # single-task
510
559
  if self.nums_task == 1:
511
560
  if y_pred.dim() == 1:
512
561
  y_pred = y_pred.view(-1, 1)
@@ -514,16 +563,30 @@ class BaseModel(FeatureSet, nn.Module):
514
563
  y_true = y_true.view(-1, 1)
515
564
  if y_pred.shape != y_true.shape:
516
565
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
566
+ loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
567
+ if loss_fn is None:
568
+ raise ValueError(
569
+ "[BaseModel-compute_loss Error] Loss function is not configured. Call compile() first."
570
+ )
571
+ mode = self.training_modes[0]
517
572
  task_dim = (
518
573
  self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
519
574
  )
520
- if task_dim == 1:
521
- loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
575
+ if mode in {"pairwise", "listwise"}:
576
+ loss = compute_ranking_loss(
577
+ training_mode=mode,
578
+ loss_fn=loss_fn,
579
+ y_pred=y_pred,
580
+ y_true=y_true,
581
+ )
582
+ elif task_dim == 1:
583
+ loss = loss_fn(y_pred.view(-1), y_true.view(-1))
522
584
  else:
523
- loss = self.loss_fn[0](y_pred, y_true)
585
+ loss = loss_fn(y_pred, y_true)
524
586
  if self.loss_weights is not None:
525
587
  loss *= self.loss_weights[0]
526
588
  return loss
589
+
527
590
  # multi-task
528
591
  if y_pred.shape != y_true.shape:
529
592
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
@@ -536,7 +599,16 @@ class BaseModel(FeatureSet, nn.Module):
536
599
  for i, (start, end) in enumerate(slices): # type: ignore
537
600
  y_pred_i = y_pred[:, start:end]
538
601
  y_true_i = y_true[:, start:end]
539
- task_loss = self.loss_fn[i](y_pred_i, y_true_i)
602
+ mode = self.training_modes[i]
603
+ if mode in {"pairwise", "listwise"}:
604
+ task_loss = compute_ranking_loss(
605
+ training_mode=mode,
606
+ loss_fn=self.loss_fn[i],
607
+ y_pred=y_pred_i,
608
+ y_true=y_true_i,
609
+ )
610
+ else:
611
+ task_loss = self.loss_fn[i](y_pred_i, y_true_i)
540
612
  task_losses.append(task_loss)
541
613
  if self.grad_norm is not None:
542
614
  if self.grad_norm_shared_params is None:
@@ -602,8 +674,8 @@ class BaseModel(FeatureSet, nn.Module):
602
674
  user_id_column: str | None = None,
603
675
  validation_split: float | None = None,
604
676
  num_workers: int = 0,
605
- tensorboard: bool = True,
606
- auto_distributed_sampler: bool = True,
677
+ use_tensorboard: bool = True,
678
+ auto_ddp_sampler: bool = True,
607
679
  log_interval: int = 1,
608
680
  ):
609
681
  """
@@ -619,8 +691,8 @@ class BaseModel(FeatureSet, nn.Module):
619
691
  user_id_column: Column name for GAUC-style metrics;.
620
692
  validation_split: Ratio to split training data when valid_data is None.
621
693
  num_workers: DataLoader worker count.
622
- tensorboard: Enable tensorboard logging.
623
- auto_distributed_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
694
+ use_tensorboard: Enable tensorboard logging.
695
+ auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
624
696
  log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
625
697
 
626
698
  Notes:
@@ -662,7 +734,7 @@ class BaseModel(FeatureSet, nn.Module):
662
734
  setup_logger(session_id=self.session_id)
663
735
  self.logger_initialized = True
664
736
  self.training_logger = (
665
- TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
737
+ TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
666
738
  if self.is_main_process
667
739
  else None
668
740
  )
@@ -680,18 +752,21 @@ class BaseModel(FeatureSet, nn.Module):
680
752
  if self.nums_task == 1:
681
753
  monitor_metric = f"val_{self.metrics[0]}"
682
754
  else:
683
- monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
755
+ # Determine which task to monitor for early stopping
756
+ monitor_task = self.early_stop_monitor_task
757
+ if monitor_task is None:
758
+ monitor_task = self.target_columns[0]
759
+ elif monitor_task not in self.target_columns:
760
+ raise ValueError(
761
+ f"[BaseModel-fit Error] early_stop_monitor_task '{monitor_task}' not found in target_columns {self.target_columns}."
762
+ )
763
+ monitor_metric = f"val_{self.metrics[0]}_{monitor_task}"
684
764
 
685
765
  existing_callbacks = self.callbacks.callbacks
686
- has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
687
- has_checkpoint = any(
688
- isinstance(cb, CheckpointSaver) for cb in existing_callbacks
689
- )
690
- has_lr_scheduler = any(
691
- isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
692
- )
693
766
 
694
- if self.early_stop_patience > 0 and not has_early_stop:
767
+ if self.early_stop_patience > 0 and not any(
768
+ isinstance(cb, EarlyStopper) for cb in existing_callbacks
769
+ ):
695
770
  self.callbacks.append(
696
771
  EarlyStopper(
697
772
  monitor=monitor_metric,
@@ -702,7 +777,9 @@ class BaseModel(FeatureSet, nn.Module):
702
777
  )
703
778
  )
704
779
 
705
- if self.is_main_process and not has_checkpoint:
780
+ if self.is_main_process and not any(
781
+ isinstance(cb, CheckpointSaver) for cb in existing_callbacks
782
+ ):
706
783
  self.callbacks.append(
707
784
  CheckpointSaver(
708
785
  best_path=self.best_path,
@@ -714,7 +791,9 @@ class BaseModel(FeatureSet, nn.Module):
714
791
  )
715
792
  )
716
793
 
717
- if self.scheduler_fn is not None and not has_lr_scheduler:
794
+ if self.scheduler_fn is not None and not any(
795
+ isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
796
+ ):
718
797
  self.callbacks.append(
719
798
  LearningRateScheduler(
720
799
  scheduler=self.scheduler_fn,
@@ -737,16 +816,16 @@ class BaseModel(FeatureSet, nn.Module):
737
816
  self.stop_training = False
738
817
  self.best_checkpoint_path = self.best_path
739
818
  use_ddp_sampler = (
740
- auto_distributed_sampler
819
+ auto_ddp_sampler
741
820
  and self.distributed
742
821
  and dist.is_available()
743
822
  and dist.is_initialized()
744
823
  )
745
824
 
746
- if not auto_distributed_sampler and self.distributed and self.is_main_process:
825
+ if not auto_ddp_sampler and self.distributed and self.is_main_process:
747
826
  logging.info(
748
827
  colorize(
749
- "[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.",
828
+ "[Distributed Info] auto_ddp_sampler=False; assuming data is already sharded per rank.",
750
829
  color="yellow",
751
830
  )
752
831
  )
@@ -825,12 +904,12 @@ class BaseModel(FeatureSet, nn.Module):
825
904
  # If split-based loader was built without sampler, attach here when enabled
826
905
  if (
827
906
  self.distributed
828
- and auto_distributed_sampler
907
+ and auto_ddp_sampler
829
908
  and isinstance(train_loader, DataLoader)
830
909
  and train_sampler is None
831
910
  ):
832
911
  raise NotImplementedError(
833
- "[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
912
+ "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
834
913
  )
835
914
  # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
836
915
 
@@ -840,7 +919,7 @@ class BaseModel(FeatureSet, nn.Module):
840
919
  needs_user_ids=self.needs_user_ids,
841
920
  user_id_column=user_id_column,
842
921
  num_workers=num_workers,
843
- auto_distributed_sampler=auto_distributed_sampler,
922
+ auto_ddp_sampler=auto_ddp_sampler,
844
923
  )
845
924
  try:
846
925
  self.steps_per_epoch = len(train_loader)
@@ -862,7 +941,7 @@ class BaseModel(FeatureSet, nn.Module):
862
941
  logging.info("")
863
942
  tb_dir = (
864
943
  self.training_logger.tensorboard_logdir
865
- if self.training_logger and self.training_logger.enable_tensorboard
944
+ if self.training_logger and self.training_logger.use_tensorboard
866
945
  else None
867
946
  )
868
947
  if tb_dir:
@@ -1054,7 +1133,7 @@ class BaseModel(FeatureSet, nn.Module):
1054
1133
  y_true_list = []
1055
1134
  y_pred_list = []
1056
1135
  collect_metrics = getattr(self, "collect_train_metrics", True)
1057
- max_samples = getattr(self, "max_metrics_samples", None)
1136
+ max_samples = getattr(self, "metrics_sample_limit", None)
1058
1137
  collected_samples = 0
1059
1138
  metrics_capped = False
1060
1139
 
@@ -1183,14 +1262,14 @@ class BaseModel(FeatureSet, nn.Module):
1183
1262
  needs_user_ids: bool,
1184
1263
  user_id_column: str | None = "user_id",
1185
1264
  num_workers: int = 0,
1186
- auto_distributed_sampler: bool = True,
1265
+ auto_ddp_sampler: bool = True,
1187
1266
  ) -> tuple[DataLoader | None, np.ndarray | None]:
1188
1267
  if valid_data is None:
1189
1268
  return None, None
1190
1269
  if isinstance(valid_data, DataLoader):
1191
- if auto_distributed_sampler and self.distributed:
1270
+ if auto_ddp_sampler and self.distributed:
1192
1271
  raise NotImplementedError(
1193
- "[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
1272
+ "[BaseModel-prepare_validation_data Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
1194
1273
  )
1195
1274
  # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
1196
1275
  else:
@@ -1199,7 +1278,7 @@ class BaseModel(FeatureSet, nn.Module):
1199
1278
  valid_sampler = None
1200
1279
  valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
1201
1280
  if (
1202
- auto_distributed_sampler
1281
+ auto_ddp_sampler
1203
1282
  and self.distributed
1204
1283
  and valid_dataset is not None
1205
1284
  and dist.is_available()
@@ -1372,11 +1451,11 @@ class BaseModel(FeatureSet, nn.Module):
1372
1451
  data: str | dict | pd.DataFrame | DataLoader,
1373
1452
  batch_size: int = 32,
1374
1453
  save_path: str | os.PathLike | None = None,
1375
- save_format: Literal["csv", "parquet"] = "csv",
1454
+ save_format: str = "csv",
1376
1455
  include_ids: bool | None = None,
1377
1456
  id_columns: str | list[str] | None = None,
1378
1457
  return_dataframe: bool = True,
1379
- streaming_chunk_size: int = 10000,
1458
+ stream_chunk_size: int = 10000,
1380
1459
  num_workers: int = 0,
1381
1460
  ) -> pd.DataFrame | np.ndarray | Path | None:
1382
1461
  """
@@ -1391,7 +1470,7 @@ class BaseModel(FeatureSet, nn.Module):
1391
1470
  include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
1392
1471
  id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
1393
1472
  return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
1394
- streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
1473
+ stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
1395
1474
  num_workers: DataLoader worker count.
1396
1475
  """
1397
1476
  self.eval()
@@ -1412,7 +1491,7 @@ class BaseModel(FeatureSet, nn.Module):
1412
1491
  save_path=save_path,
1413
1492
  save_format=save_format,
1414
1493
  include_ids=include_ids,
1415
- streaming_chunk_size=streaming_chunk_size,
1494
+ stream_chunk_size=stream_chunk_size,
1416
1495
  return_dataframe=return_dataframe,
1417
1496
  id_columns=predict_id_columns,
1418
1497
  )
@@ -1438,7 +1517,7 @@ class BaseModel(FeatureSet, nn.Module):
1438
1517
  batch_size=batch_size,
1439
1518
  shuffle=False,
1440
1519
  streaming=True,
1441
- chunk_size=streaming_chunk_size,
1520
+ chunk_size=stream_chunk_size,
1442
1521
  )
1443
1522
  else:
1444
1523
  data_loader = self.prepare_data_loader(
@@ -1516,11 +1595,18 @@ class BaseModel(FeatureSet, nn.Module):
1516
1595
  else y_pred_all
1517
1596
  )
1518
1597
  if save_path is not None:
1519
- if save_format not in ("csv", "parquet"):
1520
- raise ValueError(
1521
- f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'."
1598
+ # Check streaming write support
1599
+ if not check_streaming_support(save_format):
1600
+ logging.warning(
1601
+ f"[BaseModel-predict Warning] Format '{save_format}' does not support streaming writes. "
1602
+ "The entire result will be saved at once. Use csv or parquet for large datasets."
1522
1603
  )
1523
- suffix = ".csv" if save_format == "csv" else ".parquet"
1604
+
1605
+ # Get file extension from format
1606
+ from nextrec.utils.data import FILE_FORMAT_CONFIG
1607
+
1608
+ suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
1609
+
1524
1610
  target_path = resolve_save_path(
1525
1611
  path=save_path,
1526
1612
  default_dir=self.session.predictions_dir,
@@ -1539,10 +1625,21 @@ class BaseModel(FeatureSet, nn.Module):
1539
1625
  f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
1540
1626
  )
1541
1627
  df_to_save = pd.concat([id_df, df_to_save], axis=1)
1628
+
1629
+ # Save based on format
1542
1630
  if save_format == "csv":
1543
1631
  df_to_save.to_csv(target_path, index=False)
1544
- else:
1632
+ elif save_format == "parquet":
1545
1633
  df_to_save.to_parquet(target_path, index=False)
1634
+ elif save_format == "feather":
1635
+ df_to_save.to_feather(target_path)
1636
+ elif save_format == "excel":
1637
+ df_to_save.to_excel(target_path, index=False)
1638
+ elif save_format == "hdf5":
1639
+ df_to_save.to_hdf(target_path, key="predictions", mode="w")
1640
+ else:
1641
+ raise ValueError(f"Unsupported save format: {save_format}")
1642
+
1546
1643
  logging.info(
1547
1644
  colorize(f"Predictions saved to: {target_path}", color="green")
1548
1645
  )
@@ -1553,9 +1650,9 @@ class BaseModel(FeatureSet, nn.Module):
1553
1650
  data: str | dict | pd.DataFrame | DataLoader,
1554
1651
  batch_size: int,
1555
1652
  save_path: str | os.PathLike,
1556
- save_format: Literal["csv", "parquet"],
1653
+ save_format: str,
1557
1654
  include_ids: bool,
1558
- streaming_chunk_size: int,
1655
+ stream_chunk_size: int,
1559
1656
  return_dataframe: bool,
1560
1657
  id_columns: list[str] | None = None,
1561
1658
  ) -> pd.DataFrame | Path:
@@ -1572,7 +1669,7 @@ class BaseModel(FeatureSet, nn.Module):
1572
1669
  batch_size=batch_size,
1573
1670
  shuffle=False,
1574
1671
  streaming=True,
1575
- chunk_size=streaming_chunk_size,
1672
+ chunk_size=stream_chunk_size,
1576
1673
  )
1577
1674
  elif not isinstance(data, DataLoader):
1578
1675
  data_loader = self.prepare_data_loader(
@@ -1594,7 +1691,17 @@ class BaseModel(FeatureSet, nn.Module):
1594
1691
  "When using streaming mode, set num_workers=0 to avoid reading data multiple times."
1595
1692
  )
1596
1693
 
1597
- suffix = ".csv" if save_format == "csv" else ".parquet"
1694
+ # Check streaming support and prepare file path
1695
+ if not check_streaming_support(save_format):
1696
+ logging.warning(
1697
+ f"[Predict Streaming Warning] Format '{save_format}' does not support streaming writes. "
1698
+ "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
1699
+ )
1700
+
1701
+ from nextrec.utils.data import FILE_FORMAT_CONFIG
1702
+
1703
+ suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
1704
+
1598
1705
  target_path = resolve_save_path(
1599
1706
  path=save_path,
1600
1707
  default_dir=self.session.predictions_dir,
@@ -1605,9 +1712,10 @@ class BaseModel(FeatureSet, nn.Module):
1605
1712
  target_path.parent.mkdir(parents=True, exist_ok=True)
1606
1713
  header_written = target_path.exists() and target_path.stat().st_size > 0
1607
1714
  parquet_writer = None
1608
-
1609
1715
  pred_columns = None
1610
- collected_frames = [] # only used when return_dataframe is True
1716
+ collected_frames = (
1717
+ []
1718
+ ) # used when return_dataframe=True or for non-streaming formats
1611
1719
 
1612
1720
  with torch.no_grad():
1613
1721
  for batch_data in progress(data_loader, description="Predicting"):
@@ -1649,27 +1757,48 @@ class BaseModel(FeatureSet, nn.Module):
1649
1757
  )
1650
1758
  df_batch = pd.concat([id_df, df_batch], axis=1)
1651
1759
 
1760
+ # Streaming save based on format
1652
1761
  if save_format == "csv":
1653
1762
  df_batch.to_csv(
1654
1763
  target_path, mode="a", header=not header_written, index=False
1655
1764
  )
1656
1765
  header_written = True
1657
- else:
1766
+ elif save_format == "parquet":
1658
1767
  try:
1659
1768
  import pyarrow as pa
1660
1769
  import pyarrow.parquet as pq
1661
1770
  except ImportError as exc: # pragma: no cover
1662
1771
  raise ImportError(
1663
- "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed."
1772
+ "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow."
1664
1773
  ) from exc
1665
1774
  table = pa.Table.from_pandas(df_batch, preserve_index=False)
1666
1775
  if parquet_writer is None:
1667
1776
  parquet_writer = pq.ParquetWriter(target_path, table.schema)
1668
1777
  parquet_writer.write_table(table)
1669
- if return_dataframe:
1778
+ else:
1779
+ # Non-streaming formats: collect all data
1670
1780
  collected_frames.append(df_batch)
1781
+
1782
+ if return_dataframe:
1783
+ if (
1784
+ save_format in ["csv", "parquet"]
1785
+ and df_batch not in collected_frames
1786
+ ):
1787
+ collected_frames.append(df_batch)
1788
+
1789
+ # Close writers
1671
1790
  if parquet_writer is not None:
1672
1791
  parquet_writer.close()
1792
+ # For non-streaming formats, save collected data
1793
+ if save_format in ["feather", "excel", "hdf5"] and collected_frames:
1794
+ combined_df = pd.concat(collected_frames, ignore_index=True)
1795
+ if save_format == "feather":
1796
+ combined_df.to_feather(target_path)
1797
+ elif save_format == "excel":
1798
+ combined_df.to_excel(target_path, index=False)
1799
+ elif save_format == "hdf5":
1800
+ combined_df.to_hdf(target_path, key="predictions", mode="w")
1801
+
1673
1802
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
1674
1803
  if return_dataframe:
1675
1804
  return (
@@ -1690,7 +1819,7 @@ class BaseModel(FeatureSet, nn.Module):
1690
1819
  target_path = resolve_save_path(
1691
1820
  path=save_path,
1692
1821
  default_dir=self.session_path,
1693
- default_name=self.model_name,
1822
+ default_name=self.model_name.upper(),
1694
1823
  suffix=".pt",
1695
1824
  add_timestamp=add_timestamp,
1696
1825
  )
@@ -1844,7 +1973,9 @@ class BaseModel(FeatureSet, nn.Module):
1844
1973
  logger.info("")
1845
1974
  logger.info(
1846
1975
  colorize(
1847
- f"Model Summary: {self.model_name}", color="bright_blue", bold=True
1976
+ f"Model Summary: {self.model_name.upper()}",
1977
+ color="bright_blue",
1978
+ bold=True,
1848
1979
  )
1849
1980
  )
1850
1981
  logger.info("")
@@ -1975,7 +2106,7 @@ class BaseModel(FeatureSet, nn.Module):
1975
2106
  logger.info("Other Settings:")
1976
2107
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
1977
2108
  logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
1978
- logger.info(f" Max Metrics Samples: {self.max_metrics_samples}")
2109
+ logger.info(f" Max Metrics Samples: {self.metrics_sample_limit}")
1979
2110
  logger.info(f" Session ID: {self.session_id}")
1980
2111
  logger.info(f" Features Config Path: {self.features_config_path}")
1981
2112
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
@@ -2115,6 +2246,12 @@ class BaseMatchModel(BaseModel):
2115
2246
  )
2116
2247
  self.user_feature_names = {feature.name for feature in self.user_features_all}
2117
2248
  self.item_feature_names = {feature.name for feature in self.item_features_all}
2249
+ self.head = RetrievalHead(
2250
+ similarity_metric=self.similarity_metric,
2251
+ temperature=self.temperature,
2252
+ training_mode=self.training_mode,
2253
+ apply_sigmoid=True,
2254
+ )
2118
2255
 
2119
2256
  def compile(
2120
2257
  self,
@@ -2139,7 +2276,7 @@ class BaseMatchModel(BaseModel):
2139
2276
  """
2140
2277
  if self.training_mode not in self.support_training_modes:
2141
2278
  raise ValueError(
2142
- f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
2279
+ f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
2143
2280
  )
2144
2281
 
2145
2282
  default_loss_by_mode: dict[str, str] = {
@@ -2244,15 +2381,7 @@ class BaseMatchModel(BaseModel):
2244
2381
  user_emb = self.user_tower(user_input) # [B, D]
2245
2382
  item_emb = self.item_tower(item_input) # [B, D]
2246
2383
 
2247
- if self.training and self.training_mode in ["pairwise", "listwise"]:
2248
- return user_emb, item_emb
2249
-
2250
- similarity = self.compute_similarity(user_emb, item_emb) # [B]
2251
-
2252
- if self.training_mode == "pointwise":
2253
- return torch.sigmoid(similarity)
2254
- else:
2255
- return similarity
2384
+ return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)
2256
2385
 
2257
2386
  def compute_loss(self, y_pred, y_true):
2258
2387
  if self.training_mode == "pointwise":
@@ -2308,7 +2437,7 @@ class BaseMatchModel(BaseModel):
2308
2437
  features: list,
2309
2438
  batch_size: int,
2310
2439
  num_workers: int = 0,
2311
- streaming_chunk_size: int = 10000,
2440
+ stream_chunk_size: int = 10000,
2312
2441
  ) -> DataLoader:
2313
2442
  """Prepare data loader for specific features."""
2314
2443
  if isinstance(data, DataLoader):
@@ -2329,7 +2458,7 @@ class BaseMatchModel(BaseModel):
2329
2458
  batch_size=batch_size,
2330
2459
  shuffle=False,
2331
2460
  streaming=True,
2332
- chunk_size=streaming_chunk_size,
2461
+ chunk_size=stream_chunk_size,
2333
2462
  num_workers=num_workers,
2334
2463
  )
2335
2464
  tensors = build_tensors_from_data(
@@ -2382,7 +2511,7 @@ class BaseMatchModel(BaseModel):
2382
2511
  ),
2383
2512
  batch_size: int = 512,
2384
2513
  num_workers: int = 0,
2385
- streaming_chunk_size: int = 10000,
2514
+ stream_chunk_size: int = 10000,
2386
2515
  ) -> np.ndarray:
2387
2516
  self.eval()
2388
2517
  data_loader = self.prepare_feature_data(
@@ -2390,7 +2519,7 @@ class BaseMatchModel(BaseModel):
2390
2519
  self.user_features_all,
2391
2520
  batch_size,
2392
2521
  num_workers=num_workers,
2393
- streaming_chunk_size=streaming_chunk_size,
2522
+ stream_chunk_size=stream_chunk_size,
2394
2523
  )
2395
2524
 
2396
2525
  embeddings_list = []
@@ -2416,7 +2545,7 @@ class BaseMatchModel(BaseModel):
2416
2545
  ),
2417
2546
  batch_size: int = 512,
2418
2547
  num_workers: int = 0,
2419
- streaming_chunk_size: int = 10000,
2548
+ stream_chunk_size: int = 10000,
2420
2549
  ) -> np.ndarray:
2421
2550
  self.eval()
2422
2551
  data_loader = self.prepare_feature_data(
@@ -2424,7 +2553,7 @@ class BaseMatchModel(BaseModel):
2424
2553
  self.item_features_all,
2425
2554
  batch_size,
2426
2555
  num_workers=num_workers,
2427
- streaming_chunk_size=streaming_chunk_size,
2556
+ stream_chunk_size=stream_chunk_size,
2428
2557
  )
2429
2558
 
2430
2559
  embeddings_list = []