nextrec 0.4.17__py3-none-any.whl → 0.4.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 20/12/2025
5
+ Checkpoint: edit on 24/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -49,6 +49,7 @@ from nextrec.data.dataloader import (
49
49
  TensorDictDataset,
50
50
  build_tensors_from_data,
51
51
  )
52
+ from nextrec.utils.data import check_streaming_support
52
53
  from nextrec.loss import (
53
54
  BPRLoss,
54
55
  GradNormLossWeighting,
@@ -69,6 +70,7 @@ from nextrec.utils.torch_utils import (
69
70
  init_process_group,
70
71
  to_tensor,
71
72
  )
73
+ from nextrec.utils.model import compute_ranking_loss
72
74
 
73
75
 
74
76
  class BaseModel(FeatureSet, nn.Module):
@@ -88,13 +90,18 @@ class BaseModel(FeatureSet, nn.Module):
88
90
  target: list[str] | str | None = None,
89
91
  id_columns: list[str] | str | None = None,
90
92
  task: str | list[str] | None = None,
93
+ training_mode: (
94
+ Literal["pointwise", "pairwise", "listwise"]
95
+ | list[Literal["pointwise", "pairwise", "listwise"]]
96
+ ) = "pointwise",
91
97
  embedding_l1_reg: float = 0.0,
92
98
  dense_l1_reg: float = 0.0,
93
99
  embedding_l2_reg: float = 0.0,
94
100
  dense_l2_reg: float = 0.0,
95
101
  device: str = "cpu",
96
102
  early_stop_patience: int = 20,
97
- max_metrics_samples: int | None = 200000,
103
+ early_stop_monitor_task: str | None = None,
104
+ metrics_sample_limit: int | None = 200000,
98
105
  session_id: str | None = None,
99
106
  callbacks: list[Callback] | None = None,
100
107
  distributed: bool = False,
@@ -113,6 +120,7 @@ class BaseModel(FeatureSet, nn.Module):
113
120
  target: Target column name. e.g., 'label' or ['label1', 'label2'].
114
121
  id_columns: Identifier column name; only needs to be specified if GAUC is required. e.g., 'user_id'.
115
122
  task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
123
+ training_mode: Training mode for ranking tasks; a single mode or a list per task.
116
124
 
117
125
  embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
118
126
  dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
@@ -121,7 +129,8 @@ class BaseModel(FeatureSet, nn.Module):
121
129
 
122
130
  device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
123
131
  early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
124
- max_metrics_samples: Max samples to keep for training metrics. None disables limit.
132
+ early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
133
+ metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
125
134
  session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
126
135
  callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].
127
136
 
@@ -150,9 +159,11 @@ class BaseModel(FeatureSet, nn.Module):
150
159
  self.session = create_session(session_id)
151
160
  self.session_path = self.session.root # pwd/session_id, path for this session
152
161
  self.checkpoint_path = os.path.join(
153
- self.session_path, self.model_name + "_checkpoint.pt"
162
+ self.session_path, self.model_name.upper() + "_checkpoint.pt"
154
163
  ) # e.g., pwd/session_id/DeepFM_checkpoint.pt
155
- self.best_path = os.path.join(self.session_path, self.model_name + "_best.pt")
164
+ self.best_path = os.path.join(
165
+ self.session_path, self.model_name.upper() + "_best.pt"
166
+ )
156
167
  self.features_config_path = os.path.join(
157
168
  self.session_path, "features_config.pkl"
158
169
  )
@@ -162,6 +173,22 @@ class BaseModel(FeatureSet, nn.Module):
162
173
 
163
174
  self.task = self.default_task if task is None else task
164
175
  self.nums_task = len(self.task) if isinstance(self.task, list) else 1
176
+ if isinstance(training_mode, list):
177
+ if len(training_mode) != self.nums_task:
178
+ raise ValueError(
179
+ "[BaseModel-init Error] training_mode list length must match number of tasks."
180
+ )
181
+ self.training_modes = list(training_mode)
182
+ else:
183
+ self.training_modes = [training_mode] * self.nums_task
184
+ for mode in self.training_modes:
185
+ if mode not in {"pointwise", "pairwise", "listwise"}:
186
+ raise ValueError(
187
+ "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
188
+ )
189
+ self.training_mode = (
190
+ self.training_modes if self.nums_task > 1 else self.training_modes[0]
191
+ )
165
192
 
166
193
  self.embedding_l1_reg = embedding_l1_reg
167
194
  self.dense_l1_reg = dense_l1_reg
@@ -172,9 +199,10 @@ class BaseModel(FeatureSet, nn.Module):
172
199
  self.loss_weight = None
173
200
 
174
201
  self.early_stop_patience = early_stop_patience
202
+ self.early_stop_monitor_task = early_stop_monitor_task
175
203
  # max samples to keep for training metrics, in case of large training set
176
- self.max_metrics_samples = (
177
- None if max_metrics_samples is None else int(max_metrics_samples)
204
+ self.metrics_sample_limit = (
205
+ None if metrics_sample_limit is None else int(metrics_sample_limit)
178
206
  )
179
207
  self.max_gradient_norm = 1.0
180
208
  self.logger_initialized = False
@@ -398,6 +426,33 @@ class BaseModel(FeatureSet, nn.Module):
398
426
  Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
399
427
  callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
400
428
  """
429
+ default_losses = {
430
+ "pointwise": "bce",
431
+ "pairwise": "bpr",
432
+ "listwise": "listnet",
433
+ }
434
+ effective_loss = loss
435
+ if effective_loss is None:
436
+ loss_list = [default_losses[mode] for mode in self.training_modes]
437
+ elif isinstance(effective_loss, list):
438
+ if not effective_loss:
439
+ loss_list = [default_losses[mode] for mode in self.training_modes]
440
+ else:
441
+ if len(effective_loss) != self.nums_task:
442
+ raise ValueError(
443
+ f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
444
+ )
445
+ loss_list = list(effective_loss)
446
+ else:
447
+ loss_list = [effective_loss] * self.nums_task
448
+
449
+ for idx, mode in enumerate(self.training_modes):
450
+ if isinstance(loss_list[idx], str) and loss_list[idx] in {
451
+ "bce",
452
+ "binary_crossentropy",
453
+ }:
454
+ if mode in {"pairwise", "listwise"}:
455
+ loss_list[idx] = default_losses[mode]
401
456
  if loss_params is None:
402
457
  self.loss_params = {}
403
458
  else:
@@ -427,16 +482,8 @@ class BaseModel(FeatureSet, nn.Module):
427
482
  else None
428
483
  )
429
484
 
430
- self.loss_config = loss
485
+ self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
431
486
  self.loss_params = loss_params or {}
432
- if isinstance(loss, list):
433
- if len(loss) != self.nums_task:
434
- raise ValueError(
435
- f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
436
- )
437
- loss_list = list(loss)
438
- else:
439
- loss_list = [loss] * self.nums_task
440
487
  if isinstance(self.loss_params, dict):
441
488
  loss_params_list = [self.loss_params] * self.nums_task
442
489
  else:
@@ -457,7 +504,7 @@ class BaseModel(FeatureSet, nn.Module):
457
504
  "[BaseModel-compile Error] GradNorm requires multi-task setup."
458
505
  )
459
506
  self.grad_norm = GradNormLossWeighting(
460
- num_tasks=self.nums_task, device=self.device
507
+ nums_task=self.nums_task, device=self.device
461
508
  )
462
509
  self.loss_weights = None
463
510
  elif (
@@ -470,7 +517,7 @@ class BaseModel(FeatureSet, nn.Module):
470
517
  grad_norm_params = dict(loss_weights)
471
518
  grad_norm_params.pop("method", None)
472
519
  self.grad_norm = GradNormLossWeighting(
473
- num_tasks=self.nums_task, device=self.device, **grad_norm_params
520
+ nums_task=self.nums_task, device=self.device, **grad_norm_params
474
521
  )
475
522
  self.loss_weights = None
476
523
  elif loss_weights is None:
@@ -508,6 +555,7 @@ class BaseModel(FeatureSet, nn.Module):
508
555
  raise ValueError(
509
556
  "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
510
557
  )
558
+ # single-task
511
559
  if self.nums_task == 1:
512
560
  if y_pred.dim() == 1:
513
561
  y_pred = y_pred.view(-1, 1)
@@ -515,16 +563,30 @@ class BaseModel(FeatureSet, nn.Module):
515
563
  y_true = y_true.view(-1, 1)
516
564
  if y_pred.shape != y_true.shape:
517
565
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
566
+ loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
567
+ if loss_fn is None:
568
+ raise ValueError(
569
+ "[BaseModel-compute_loss Error] Loss function is not configured. Call compile() first."
570
+ )
571
+ mode = self.training_modes[0]
518
572
  task_dim = (
519
573
  self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
520
574
  )
521
- if task_dim == 1:
522
- loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
575
+ if mode in {"pairwise", "listwise"}:
576
+ loss = compute_ranking_loss(
577
+ training_mode=mode,
578
+ loss_fn=loss_fn,
579
+ y_pred=y_pred,
580
+ y_true=y_true,
581
+ )
582
+ elif task_dim == 1:
583
+ loss = loss_fn(y_pred.view(-1), y_true.view(-1))
523
584
  else:
524
- loss = self.loss_fn[0](y_pred, y_true)
585
+ loss = loss_fn(y_pred, y_true)
525
586
  if self.loss_weights is not None:
526
587
  loss *= self.loss_weights[0]
527
588
  return loss
589
+
528
590
  # multi-task
529
591
  if y_pred.shape != y_true.shape:
530
592
  raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
@@ -537,7 +599,16 @@ class BaseModel(FeatureSet, nn.Module):
537
599
  for i, (start, end) in enumerate(slices): # type: ignore
538
600
  y_pred_i = y_pred[:, start:end]
539
601
  y_true_i = y_true[:, start:end]
540
- task_loss = self.loss_fn[i](y_pred_i, y_true_i)
602
+ mode = self.training_modes[i]
603
+ if mode in {"pairwise", "listwise"}:
604
+ task_loss = compute_ranking_loss(
605
+ training_mode=mode,
606
+ loss_fn=self.loss_fn[i],
607
+ y_pred=y_pred_i,
608
+ y_true=y_true_i,
609
+ )
610
+ else:
611
+ task_loss = self.loss_fn[i](y_pred_i, y_true_i)
541
612
  task_losses.append(task_loss)
542
613
  if self.grad_norm is not None:
543
614
  if self.grad_norm_shared_params is None:
@@ -603,8 +674,8 @@ class BaseModel(FeatureSet, nn.Module):
603
674
  user_id_column: str | None = None,
604
675
  validation_split: float | None = None,
605
676
  num_workers: int = 0,
606
- tensorboard: bool = True,
607
- auto_distributed_sampler: bool = True,
677
+ use_tensorboard: bool = True,
678
+ auto_ddp_sampler: bool = True,
608
679
  log_interval: int = 1,
609
680
  ):
610
681
  """
@@ -620,8 +691,8 @@ class BaseModel(FeatureSet, nn.Module):
620
691
  user_id_column: Column name for GAUC-style metrics.
621
692
  validation_split: Ratio to split training data when valid_data is None.
622
693
  num_workers: DataLoader worker count.
623
- tensorboard: Enable tensorboard logging.
624
- auto_distributed_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
694
+ use_tensorboard: Enable tensorboard logging.
695
+ auto_ddp_sampler: Attach DistributedSampler automatically when distributed; set False when data is already sharded per rank.
625
696
  log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
626
697
 
627
698
  Notes:
@@ -663,7 +734,7 @@ class BaseModel(FeatureSet, nn.Module):
663
734
  setup_logger(session_id=self.session_id)
664
735
  self.logger_initialized = True
665
736
  self.training_logger = (
666
- TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
737
+ TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
667
738
  if self.is_main_process
668
739
  else None
669
740
  )
@@ -681,18 +752,21 @@ class BaseModel(FeatureSet, nn.Module):
681
752
  if self.nums_task == 1:
682
753
  monitor_metric = f"val_{self.metrics[0]}"
683
754
  else:
684
- monitor_metric = f"val_{self.metrics[0]}_{self.target_columns[0]}"
755
+ # Determine which task to monitor for early stopping
756
+ monitor_task = self.early_stop_monitor_task
757
+ if monitor_task is None:
758
+ monitor_task = self.target_columns[0]
759
+ elif monitor_task not in self.target_columns:
760
+ raise ValueError(
761
+ f"[BaseModel-fit Error] early_stop_monitor_task '{monitor_task}' not found in target_columns {self.target_columns}."
762
+ )
763
+ monitor_metric = f"val_{self.metrics[0]}_{monitor_task}"
685
764
 
686
765
  existing_callbacks = self.callbacks.callbacks
687
- has_early_stop = any(isinstance(cb, EarlyStopper) for cb in existing_callbacks)
688
- has_checkpoint = any(
689
- isinstance(cb, CheckpointSaver) for cb in existing_callbacks
690
- )
691
- has_lr_scheduler = any(
692
- isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
693
- )
694
766
 
695
- if self.early_stop_patience > 0 and not has_early_stop:
767
+ if self.early_stop_patience > 0 and not any(
768
+ isinstance(cb, EarlyStopper) for cb in existing_callbacks
769
+ ):
696
770
  self.callbacks.append(
697
771
  EarlyStopper(
698
772
  monitor=monitor_metric,
@@ -703,7 +777,9 @@ class BaseModel(FeatureSet, nn.Module):
703
777
  )
704
778
  )
705
779
 
706
- if self.is_main_process and not has_checkpoint:
780
+ if self.is_main_process and not any(
781
+ isinstance(cb, CheckpointSaver) for cb in existing_callbacks
782
+ ):
707
783
  self.callbacks.append(
708
784
  CheckpointSaver(
709
785
  best_path=self.best_path,
@@ -715,7 +791,9 @@ class BaseModel(FeatureSet, nn.Module):
715
791
  )
716
792
  )
717
793
 
718
- if self.scheduler_fn is not None and not has_lr_scheduler:
794
+ if self.scheduler_fn is not None and not any(
795
+ isinstance(cb, LearningRateScheduler) for cb in existing_callbacks
796
+ ):
719
797
  self.callbacks.append(
720
798
  LearningRateScheduler(
721
799
  scheduler=self.scheduler_fn,
@@ -738,16 +816,16 @@ class BaseModel(FeatureSet, nn.Module):
738
816
  self.stop_training = False
739
817
  self.best_checkpoint_path = self.best_path
740
818
  use_ddp_sampler = (
741
- auto_distributed_sampler
819
+ auto_ddp_sampler
742
820
  and self.distributed
743
821
  and dist.is_available()
744
822
  and dist.is_initialized()
745
823
  )
746
824
 
747
- if not auto_distributed_sampler and self.distributed and self.is_main_process:
825
+ if not auto_ddp_sampler and self.distributed and self.is_main_process:
748
826
  logging.info(
749
827
  colorize(
750
- "[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.",
828
+ "[Distributed Info] auto_ddp_sampler=False; assuming data is already sharded per rank.",
751
829
  color="yellow",
752
830
  )
753
831
  )
@@ -826,12 +904,12 @@ class BaseModel(FeatureSet, nn.Module):
826
904
  # If split-based loader was built without sampler, attach here when enabled
827
905
  if (
828
906
  self.distributed
829
- and auto_distributed_sampler
907
+ and auto_ddp_sampler
830
908
  and isinstance(train_loader, DataLoader)
831
909
  and train_sampler is None
832
910
  ):
833
911
  raise NotImplementedError(
834
- "[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
912
+ "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
835
913
  )
836
914
  # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
837
915
 
@@ -841,7 +919,7 @@ class BaseModel(FeatureSet, nn.Module):
841
919
  needs_user_ids=self.needs_user_ids,
842
920
  user_id_column=user_id_column,
843
921
  num_workers=num_workers,
844
- auto_distributed_sampler=auto_distributed_sampler,
922
+ auto_ddp_sampler=auto_ddp_sampler,
845
923
  )
846
924
  try:
847
925
  self.steps_per_epoch = len(train_loader)
@@ -863,7 +941,7 @@ class BaseModel(FeatureSet, nn.Module):
863
941
  logging.info("")
864
942
  tb_dir = (
865
943
  self.training_logger.tensorboard_logdir
866
- if self.training_logger and self.training_logger.enable_tensorboard
944
+ if self.training_logger and self.training_logger.use_tensorboard
867
945
  else None
868
946
  )
869
947
  if tb_dir:
@@ -1055,7 +1133,7 @@ class BaseModel(FeatureSet, nn.Module):
1055
1133
  y_true_list = []
1056
1134
  y_pred_list = []
1057
1135
  collect_metrics = getattr(self, "collect_train_metrics", True)
1058
- max_samples = getattr(self, "max_metrics_samples", None)
1136
+ max_samples = getattr(self, "metrics_sample_limit", None)
1059
1137
  collected_samples = 0
1060
1138
  metrics_capped = False
1061
1139
 
@@ -1184,14 +1262,14 @@ class BaseModel(FeatureSet, nn.Module):
1184
1262
  needs_user_ids: bool,
1185
1263
  user_id_column: str | None = "user_id",
1186
1264
  num_workers: int = 0,
1187
- auto_distributed_sampler: bool = True,
1265
+ auto_ddp_sampler: bool = True,
1188
1266
  ) -> tuple[DataLoader | None, np.ndarray | None]:
1189
1267
  if valid_data is None:
1190
1268
  return None, None
1191
1269
  if isinstance(valid_data, DataLoader):
1192
- if auto_distributed_sampler and self.distributed:
1270
+ if auto_ddp_sampler and self.distributed:
1193
1271
  raise NotImplementedError(
1194
- "[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
1272
+ "[BaseModel-prepare_validation_data Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
1195
1273
  )
1196
1274
  # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
1197
1275
  else:
@@ -1200,7 +1278,7 @@ class BaseModel(FeatureSet, nn.Module):
1200
1278
  valid_sampler = None
1201
1279
  valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
1202
1280
  if (
1203
- auto_distributed_sampler
1281
+ auto_ddp_sampler
1204
1282
  and self.distributed
1205
1283
  and valid_dataset is not None
1206
1284
  and dist.is_available()
@@ -1373,11 +1451,11 @@ class BaseModel(FeatureSet, nn.Module):
1373
1451
  data: str | dict | pd.DataFrame | DataLoader,
1374
1452
  batch_size: int = 32,
1375
1453
  save_path: str | os.PathLike | None = None,
1376
- save_format: Literal["csv", "parquet"] = "csv",
1454
+ save_format: str = "csv",
1377
1455
  include_ids: bool | None = None,
1378
1456
  id_columns: str | list[str] | None = None,
1379
1457
  return_dataframe: bool = True,
1380
- streaming_chunk_size: int = 10000,
1458
+ stream_chunk_size: int = 10000,
1381
1459
  num_workers: int = 0,
1382
1460
  ) -> pd.DataFrame | np.ndarray | Path | None:
1383
1461
  """
@@ -1392,7 +1470,7 @@ class BaseModel(FeatureSet, nn.Module):
1392
1470
  include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
1393
1471
  id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
1394
1472
  return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
1395
- streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
1473
+ stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
1396
1474
  num_workers: DataLoader worker count.
1397
1475
  """
1398
1476
  self.eval()
@@ -1413,7 +1491,7 @@ class BaseModel(FeatureSet, nn.Module):
1413
1491
  save_path=save_path,
1414
1492
  save_format=save_format,
1415
1493
  include_ids=include_ids,
1416
- streaming_chunk_size=streaming_chunk_size,
1494
+ stream_chunk_size=stream_chunk_size,
1417
1495
  return_dataframe=return_dataframe,
1418
1496
  id_columns=predict_id_columns,
1419
1497
  )
@@ -1439,7 +1517,7 @@ class BaseModel(FeatureSet, nn.Module):
1439
1517
  batch_size=batch_size,
1440
1518
  shuffle=False,
1441
1519
  streaming=True,
1442
- chunk_size=streaming_chunk_size,
1520
+ chunk_size=stream_chunk_size,
1443
1521
  )
1444
1522
  else:
1445
1523
  data_loader = self.prepare_data_loader(
@@ -1517,11 +1595,18 @@ class BaseModel(FeatureSet, nn.Module):
1517
1595
  else y_pred_all
1518
1596
  )
1519
1597
  if save_path is not None:
1520
- if save_format not in ("csv", "parquet"):
1521
- raise ValueError(
1522
- f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'."
1598
+ # Check streaming write support
1599
+ if not check_streaming_support(save_format):
1600
+ logging.warning(
1601
+ f"[BaseModel-predict Warning] Format '{save_format}' does not support streaming writes. "
1602
+ "The entire result will be saved at once. Use csv or parquet for large datasets."
1523
1603
  )
1524
- suffix = ".csv" if save_format == "csv" else ".parquet"
1604
+
1605
+ # Get file extension from format
1606
+ from nextrec.utils.data import FILE_FORMAT_CONFIG
1607
+
1608
+ suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
1609
+
1525
1610
  target_path = resolve_save_path(
1526
1611
  path=save_path,
1527
1612
  default_dir=self.session.predictions_dir,
@@ -1540,10 +1625,21 @@ class BaseModel(FeatureSet, nn.Module):
1540
1625
  f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
1541
1626
  )
1542
1627
  df_to_save = pd.concat([id_df, df_to_save], axis=1)
1628
+
1629
+ # Save based on format
1543
1630
  if save_format == "csv":
1544
1631
  df_to_save.to_csv(target_path, index=False)
1545
- else:
1632
+ elif save_format == "parquet":
1546
1633
  df_to_save.to_parquet(target_path, index=False)
1634
+ elif save_format == "feather":
1635
+ df_to_save.to_feather(target_path)
1636
+ elif save_format == "excel":
1637
+ df_to_save.to_excel(target_path, index=False)
1638
+ elif save_format == "hdf5":
1639
+ df_to_save.to_hdf(target_path, key="predictions", mode="w")
1640
+ else:
1641
+ raise ValueError(f"Unsupported save format: {save_format}")
1642
+
1547
1643
  logging.info(
1548
1644
  colorize(f"Predictions saved to: {target_path}", color="green")
1549
1645
  )
@@ -1554,9 +1650,9 @@ class BaseModel(FeatureSet, nn.Module):
1554
1650
  data: str | dict | pd.DataFrame | DataLoader,
1555
1651
  batch_size: int,
1556
1652
  save_path: str | os.PathLike,
1557
- save_format: Literal["csv", "parquet"],
1653
+ save_format: str,
1558
1654
  include_ids: bool,
1559
- streaming_chunk_size: int,
1655
+ stream_chunk_size: int,
1560
1656
  return_dataframe: bool,
1561
1657
  id_columns: list[str] | None = None,
1562
1658
  ) -> pd.DataFrame | Path:
@@ -1573,7 +1669,7 @@ class BaseModel(FeatureSet, nn.Module):
1573
1669
  batch_size=batch_size,
1574
1670
  shuffle=False,
1575
1671
  streaming=True,
1576
- chunk_size=streaming_chunk_size,
1672
+ chunk_size=stream_chunk_size,
1577
1673
  )
1578
1674
  elif not isinstance(data, DataLoader):
1579
1675
  data_loader = self.prepare_data_loader(
@@ -1595,7 +1691,17 @@ class BaseModel(FeatureSet, nn.Module):
1595
1691
  "When using streaming mode, set num_workers=0 to avoid reading data multiple times."
1596
1692
  )
1597
1693
 
1598
- suffix = ".csv" if save_format == "csv" else ".parquet"
1694
+ # Check streaming support and prepare file path
1695
+ if not check_streaming_support(save_format):
1696
+ logging.warning(
1697
+ f"[Predict Streaming Warning] Format '{save_format}' does not support streaming writes. "
1698
+ "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
1699
+ )
1700
+
1701
+ from nextrec.utils.data import FILE_FORMAT_CONFIG
1702
+
1703
+ suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]
1704
+
1599
1705
  target_path = resolve_save_path(
1600
1706
  path=save_path,
1601
1707
  default_dir=self.session.predictions_dir,
@@ -1606,9 +1712,10 @@ class BaseModel(FeatureSet, nn.Module):
1606
1712
  target_path.parent.mkdir(parents=True, exist_ok=True)
1607
1713
  header_written = target_path.exists() and target_path.stat().st_size > 0
1608
1714
  parquet_writer = None
1609
-
1610
1715
  pred_columns = None
1611
- collected_frames = [] # only used when return_dataframe is True
1716
+ collected_frames = (
1717
+ []
1718
+ ) # used when return_dataframe=True or for non-streaming formats
1612
1719
 
1613
1720
  with torch.no_grad():
1614
1721
  for batch_data in progress(data_loader, description="Predicting"):
@@ -1650,27 +1757,48 @@ class BaseModel(FeatureSet, nn.Module):
1650
1757
  )
1651
1758
  df_batch = pd.concat([id_df, df_batch], axis=1)
1652
1759
 
1760
+ # Streaming save based on format
1653
1761
  if save_format == "csv":
1654
1762
  df_batch.to_csv(
1655
1763
  target_path, mode="a", header=not header_written, index=False
1656
1764
  )
1657
1765
  header_written = True
1658
- else:
1766
+ elif save_format == "parquet":
1659
1767
  try:
1660
1768
  import pyarrow as pa
1661
1769
  import pyarrow.parquet as pq
1662
1770
  except ImportError as exc: # pragma: no cover
1663
1771
  raise ImportError(
1664
- "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed."
1772
+ "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow."
1665
1773
  ) from exc
1666
1774
  table = pa.Table.from_pandas(df_batch, preserve_index=False)
1667
1775
  if parquet_writer is None:
1668
1776
  parquet_writer = pq.ParquetWriter(target_path, table.schema)
1669
1777
  parquet_writer.write_table(table)
1670
- if return_dataframe:
1778
+ else:
1779
+ # Non-streaming formats: collect all data
1671
1780
  collected_frames.append(df_batch)
1781
+
1782
+ if return_dataframe:
1783
+ if (
1784
+ save_format in ["csv", "parquet"]
1785
+ and df_batch not in collected_frames
1786
+ ):
1787
+ collected_frames.append(df_batch)
1788
+
1789
+ # Close writers
1672
1790
  if parquet_writer is not None:
1673
1791
  parquet_writer.close()
1792
+ # For non-streaming formats, save collected data
1793
+ if save_format in ["feather", "excel", "hdf5"] and collected_frames:
1794
+ combined_df = pd.concat(collected_frames, ignore_index=True)
1795
+ if save_format == "feather":
1796
+ combined_df.to_feather(target_path)
1797
+ elif save_format == "excel":
1798
+ combined_df.to_excel(target_path, index=False)
1799
+ elif save_format == "hdf5":
1800
+ combined_df.to_hdf(target_path, key="predictions", mode="w")
1801
+
1674
1802
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
1675
1803
  if return_dataframe:
1676
1804
  return (
@@ -1691,7 +1819,7 @@ class BaseModel(FeatureSet, nn.Module):
1691
1819
  target_path = resolve_save_path(
1692
1820
  path=save_path,
1693
1821
  default_dir=self.session_path,
1694
- default_name=self.model_name,
1822
+ default_name=self.model_name.upper(),
1695
1823
  suffix=".pt",
1696
1824
  add_timestamp=add_timestamp,
1697
1825
  )
@@ -1845,7 +1973,9 @@ class BaseModel(FeatureSet, nn.Module):
1845
1973
  logger.info("")
1846
1974
  logger.info(
1847
1975
  colorize(
1848
- f"Model Summary: {self.model_name}", color="bright_blue", bold=True
1976
+ f"Model Summary: {self.model_name.upper()}",
1977
+ color="bright_blue",
1978
+ bold=True,
1849
1979
  )
1850
1980
  )
1851
1981
  logger.info("")
@@ -1976,7 +2106,7 @@ class BaseModel(FeatureSet, nn.Module):
1976
2106
  logger.info("Other Settings:")
1977
2107
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
1978
2108
  logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
1979
- logger.info(f" Max Metrics Samples: {self.max_metrics_samples}")
2109
+ logger.info(f" Max Metrics Samples: {self.metrics_sample_limit}")
1980
2110
  logger.info(f" Session ID: {self.session_id}")
1981
2111
  logger.info(f" Features Config Path: {self.features_config_path}")
1982
2112
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
@@ -2146,7 +2276,7 @@ class BaseMatchModel(BaseModel):
2146
2276
  """
2147
2277
  if self.training_mode not in self.support_training_modes:
2148
2278
  raise ValueError(
2149
- f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
2279
+ f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
2150
2280
  )
2151
2281
 
2152
2282
  default_loss_by_mode: dict[str, str] = {
@@ -2251,9 +2381,7 @@ class BaseMatchModel(BaseModel):
2251
2381
  user_emb = self.user_tower(user_input) # [B, D]
2252
2382
  item_emb = self.item_tower(item_input) # [B, D]
2253
2383
 
2254
- return self.head(
2255
- user_emb, item_emb, similarity_fn=self.compute_similarity
2256
- )
2384
+ return self.head(user_emb, item_emb, similarity_fn=self.compute_similarity)
2257
2385
 
2258
2386
  def compute_loss(self, y_pred, y_true):
2259
2387
  if self.training_mode == "pointwise":
@@ -2309,7 +2437,7 @@ class BaseMatchModel(BaseModel):
2309
2437
  features: list,
2310
2438
  batch_size: int,
2311
2439
  num_workers: int = 0,
2312
- streaming_chunk_size: int = 10000,
2440
+ stream_chunk_size: int = 10000,
2313
2441
  ) -> DataLoader:
2314
2442
  """Prepare data loader for specific features."""
2315
2443
  if isinstance(data, DataLoader):
@@ -2330,7 +2458,7 @@ class BaseMatchModel(BaseModel):
2330
2458
  batch_size=batch_size,
2331
2459
  shuffle=False,
2332
2460
  streaming=True,
2333
- chunk_size=streaming_chunk_size,
2461
+ chunk_size=stream_chunk_size,
2334
2462
  num_workers=num_workers,
2335
2463
  )
2336
2464
  tensors = build_tensors_from_data(
@@ -2383,7 +2511,7 @@ class BaseMatchModel(BaseModel):
2383
2511
  ),
2384
2512
  batch_size: int = 512,
2385
2513
  num_workers: int = 0,
2386
- streaming_chunk_size: int = 10000,
2514
+ stream_chunk_size: int = 10000,
2387
2515
  ) -> np.ndarray:
2388
2516
  self.eval()
2389
2517
  data_loader = self.prepare_feature_data(
@@ -2391,7 +2519,7 @@ class BaseMatchModel(BaseModel):
2391
2519
  self.user_features_all,
2392
2520
  batch_size,
2393
2521
  num_workers=num_workers,
2394
- streaming_chunk_size=streaming_chunk_size,
2522
+ stream_chunk_size=stream_chunk_size,
2395
2523
  )
2396
2524
 
2397
2525
  embeddings_list = []
@@ -2417,7 +2545,7 @@ class BaseMatchModel(BaseModel):
2417
2545
  ),
2418
2546
  batch_size: int = 512,
2419
2547
  num_workers: int = 0,
2420
- streaming_chunk_size: int = 10000,
2548
+ stream_chunk_size: int = 10000,
2421
2549
  ) -> np.ndarray:
2422
2550
  self.eval()
2423
2551
  data_loader = self.prepare_feature_data(
@@ -2425,7 +2553,7 @@ class BaseMatchModel(BaseModel):
2425
2553
  self.item_features_all,
2426
2554
  batch_size,
2427
2555
  num_workers=num_workers,
2428
- streaming_chunk_size=streaming_chunk_size,
2556
+ stream_chunk_size=stream_chunk_size,
2429
2557
  )
2430
2558
 
2431
2559
  embeddings_list = []