nextrec 0.4.31__py3-none-any.whl → 0.4.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/model.py +60 -12
  3. nextrec/basic/summary.py +2 -1
  4. nextrec/cli.py +56 -41
  5. nextrec/data/batch_utils.py +2 -2
  6. nextrec/data/preprocessor.py +125 -26
  7. nextrec/models/multi_task/[pre]aitm.py +3 -3
  8. nextrec/models/multi_task/[pre]snr_trans.py +3 -3
  9. nextrec/models/multi_task/[pre]star.py +3 -3
  10. nextrec/models/multi_task/apg.py +3 -3
  11. nextrec/models/multi_task/cross_stitch.py +3 -3
  12. nextrec/models/multi_task/escm.py +3 -3
  13. nextrec/models/multi_task/esmm.py +3 -3
  14. nextrec/models/multi_task/hmoe.py +3 -3
  15. nextrec/models/multi_task/mmoe.py +3 -3
  16. nextrec/models/multi_task/pepnet.py +4 -4
  17. nextrec/models/multi_task/ple.py +3 -3
  18. nextrec/models/multi_task/poso.py +3 -3
  19. nextrec/models/multi_task/share_bottom.py +3 -3
  20. nextrec/models/ranking/afm.py +3 -2
  21. nextrec/models/ranking/autoint.py +3 -2
  22. nextrec/models/ranking/dcn.py +3 -2
  23. nextrec/models/ranking/dcn_v2.py +3 -2
  24. nextrec/models/ranking/deepfm.py +3 -2
  25. nextrec/models/ranking/dien.py +3 -2
  26. nextrec/models/ranking/din.py +3 -2
  27. nextrec/models/ranking/eulernet.py +3 -2
  28. nextrec/models/ranking/ffm.py +3 -2
  29. nextrec/models/ranking/fibinet.py +3 -2
  30. nextrec/models/ranking/fm.py +3 -2
  31. nextrec/models/ranking/lr.py +3 -2
  32. nextrec/models/ranking/masknet.py +3 -2
  33. nextrec/models/ranking/pnn.py +3 -2
  34. nextrec/models/ranking/widedeep.py +3 -2
  35. nextrec/models/ranking/xdeepfm.py +3 -2
  36. nextrec/models/tree_base/__init__.py +15 -0
  37. nextrec/models/tree_base/base.py +693 -0
  38. nextrec/models/tree_base/catboost.py +97 -0
  39. nextrec/models/tree_base/lightgbm.py +69 -0
  40. nextrec/models/tree_base/xgboost.py +61 -0
  41. nextrec/utils/config.py +1 -0
  42. nextrec/utils/types.py +2 -0
  43. {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/METADATA +5 -5
  44. {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/RECORD +47 -42
  45. {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/licenses/LICENSE +1 -1
  46. {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/WHEEL +0 -0
  47. {nextrec-0.4.31.dist-info → nextrec-0.4.33.dist-info}/entry_points.txt +0 -0
nextrec/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.4.31"
+ __version__ = "0.4.33"
nextrec/basic/model.py CHANGED
@@ -13,7 +13,7 @@ import sys
  import pickle
  import socket
  from pathlib import Path
- from typing import Any, Literal
+ from typing import Any, Literal, cast, overload

  import numpy as np
  import pandas as pd
@@ -97,6 +97,7 @@ from nextrec.utils.types import (
      SchedulerName,
      TrainingModeName,
      TaskTypeName,
+     TaskTypeInput,
      MetricsName,
  )

@@ -119,7 +120,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          sequence_features: list[SequenceFeature] | None = None,
          target: list[str] | str | None = None,
          id_columns: list[str] | str | None = None,
-         task: TaskTypeName | list[TaskTypeName] | None = None,
+         task: TaskTypeInput | list[TaskTypeInput] | None = None,
          training_mode: TrainingModeName | list[TrainingModeName] | None = None,
          embedding_l1_reg: float = 0.0,
          dense_l1_reg: float = 0.0,
@@ -193,7 +194,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              dense_features, sparse_features, sequence_features, target, id_columns
          )

-         self.task = task or self.default_task
+         self.task = cast(TaskTypeName | list[TaskTypeName], task or self.default_task)
          self.nums_task = len(self.task) if isinstance(self.task, list) else 1

          training_mode = training_mode or "pointwise"
@@ -932,6 +933,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

          existing_callbacks = self.callbacks.callbacks

+         has_validation = valid_data is not None or valid_split is not None
+         checkpoint_monitor = monitor_metric
+         checkpoint_mode = self.best_metrics_mode
+         if not has_validation:
+             checkpoint_monitor = "loss"
+             checkpoint_mode = "min"
+
          if self.early_stop_patience > 0 and not any(
              isinstance(cb, EarlyStopper) for cb in existing_callbacks
          ):
@@ -945,6 +953,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              )
          )

+         has_validation = valid_data is not None or valid_split is not None
+
          if self.is_main_process and not any(
              isinstance(cb, CheckpointSaver) for cb in existing_callbacks
          ):
@@ -952,9 +962,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              CheckpointSaver(
                  best_path=self.best_path,
                  checkpoint_path=self.checkpoint_path,
-                 monitor=monitor_metric,
-                 mode=self.best_metrics_mode,
-                 save_best_only=True,
+                 monitor=checkpoint_monitor,
+                 mode=checkpoint_mode,
+                 save_best_only=has_validation,
                  verbose=1,
              )
          )
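Note on the CheckpointSaver change above: when neither valid_data nor valid_split is supplied, the saver now falls back to monitoring training loss (minimized) and keeps every epoch's checkpoint instead of only the best one. A minimal sketch of that selection logic, reusing the names from the diff (the standalone function itself is illustrative, not part of nextrec):

```python
# Illustrative only: mirrors the fallback wired into the CheckpointSaver above.
def select_checkpoint_config(
    monitor_metric: str, best_metrics_mode: str, has_validation: bool
) -> tuple[str, str, bool]:
    """Return (monitor, mode, save_best_only) for the checkpoint callback."""
    if has_validation:
        # Validation metrics exist: keep only the best checkpoint.
        return monitor_metric, best_metrics_mode, True
    # No validation: "best" is undefined, so track training loss (minimize)
    # and save a checkpoint every epoch (save_best_only=False).
    return "loss", "min", False

assert select_checkpoint_config("auc", "max", True) == ("auc", "max", True)
assert select_checkpoint_config("auc", "max", False) == ("loss", "min", False)
```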
@@ -1245,11 +1255,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  epoch_logs[f"val_{k}"] = v
          else:
              epoch_logs = {**train_log_payload}
-             if self.is_main_process:
-                 self.save_model(
-                     self.checkpoint_path, add_timestamp=False, verbose=False
-                 )
-                 self.best_checkpoint_path = self.checkpoint_path

          # Call on_epoch_end for all callbacks (handles early stopping, checkpointing, lr scheduling)
          self.callbacks.on_epoch_end(epoch, epoch_logs)
@@ -1623,6 +1628,49 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          )
          return metrics_dict

+     @overload
+     def predict(
+         self,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int = 32,
+         save_path: str | os.PathLike | None = None,
+         save_format: str = "csv",
+         include_ids: bool | None = None,
+         id_columns: str | list[str] | None = None,
+         return_dataframe: Literal[True] = True,
+         stream_chunk_size: int = 10000,
+         num_workers: int = 0,
+     ) -> pd.DataFrame: ...
+
+     @overload
+     def predict(
+         self,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int = 32,
+         save_path: None = None,
+         save_format: str = "csv",
+         include_ids: bool | None = None,
+         id_columns: str | list[str] | None = None,
+         return_dataframe: Literal[False] = False,
+         stream_chunk_size: int = 10000,
+         num_workers: int = 0,
+     ) -> np.ndarray: ...
+
+     @overload
+     def predict(
+         self,
+         data: str | dict | pd.DataFrame | DataLoader,
+         batch_size: int = 32,
+         *,
+         save_path: str | os.PathLike,
+         save_format: str = "csv",
+         include_ids: bool | None = None,
+         id_columns: str | list[str] | None = None,
+         return_dataframe: Literal[False] = False,
+         stream_chunk_size: int = 10000,
+         num_workers: int = 0,
+     ) -> Path: ...
+
      def predict(
          self,
          data: str | dict | pd.DataFrame | DataLoader,
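The three new predict overloads encode return-type dispatch for static type checkers: return_dataframe=True yields a pd.DataFrame, the default yields a np.ndarray, and a keyword-only save_path yields the Path of the written file. A self-contained sketch of the same pattern (predict_sketch is a stand-in, not the nextrec API):

```python
# Standalone sketch of the overload-based dispatch used above; only the
# parameters that drive the return type are kept.
from pathlib import Path
from typing import Literal, overload

import numpy as np
import pandas as pd

@overload
def predict_sketch(return_dataframe: Literal[True], save_path: None = None) -> pd.DataFrame: ...
@overload
def predict_sketch(return_dataframe: Literal[False] = False, save_path: None = None) -> np.ndarray: ...
@overload
def predict_sketch(return_dataframe: Literal[False] = False, *, save_path: str) -> Path: ...
def predict_sketch(return_dataframe: bool = False, save_path: str | None = None):
    preds = np.array([0.1, 0.9])
    if save_path is not None:
        out = Path(save_path)
        pd.DataFrame({"prediction": preds}).to_csv(out, index=False)
        return out  # callers get the written file's path back
    return pd.DataFrame({"prediction": preds}) if return_dataframe else preds

arr = predict_sketch()                         # statically: np.ndarray
frame = predict_sketch(return_dataframe=True)  # statically: pd.DataFrame
```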
@@ -2225,7 +2273,7 @@ class BaseMatchModel(BaseModel):
          dense_l2_reg: float = 0.0,
          target: list[str] | str | None = "label",
          id_columns: list[str] | str | None = None,
-         task: str | list[str] | None = None,
+         task: TaskTypeInput | list[TaskTypeInput] | None = None,
          session_id: str | None = None,
          distributed: bool = False,
          rank: int | None = None,
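Across BaseModel and BaseMatchModel, the task parameter moves from TaskTypeName (or bare str) to TaskTypeInput, with a cast back to TaskTypeName once defaults are applied. The +2-line change to nextrec/utils/types.py is not shown in this diff, so the aliases below are assumptions sketching the likely relationship:

```python
# Hypothetical shapes for the two aliases; the actual definitions in
# nextrec/utils/types.py are not displayed in this diff.
from typing import Literal, cast

TaskTypeName = Literal["binary", "regression"]  # canonical form, assumed values
TaskTypeInput = str                             # wider spelling accepted at the API edge

def resolve_task(task: TaskTypeInput | None, default: TaskTypeName) -> TaskTypeName:
    # Mirrors `self.task = cast(TaskTypeName | ..., task or self.default_task)`:
    # the runtime value passes through; only the static type is narrowed.
    return cast(TaskTypeName, task or default)

assert resolve_task(None, "binary") == "binary"
assert resolve_task("regression", "binary") == "regression"
```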
nextrec/basic/summary.py CHANGED
@@ -73,7 +73,8 @@ class SummarySet:
      def build_data_summary(
          self, data: Any, data_loader: DataLoader | None, sample_key: str
      ):
-         dataset = data_loader.dataset if data_loader else None
+
+         dataset = data_loader.dataset if data_loader is not None else None

          train_size = get_data_length(dataset)
          if train_size is None:
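The summary.py fix matters because truthiness of an object without __bool__ falls back to __len__: for a DataLoader over an iterable-style dataset, len() can raise, and an empty loader is falsy even though it is not None. A stand-in illustration:

```python
# Why `is not None` beats truthiness here, shown with a minimal stand-in class
# rather than a real torch DataLoader.
class LengthlessLoader:
    def __len__(self):  # mimics a streaming loader whose length is undefined
        raise TypeError("length is not defined for streaming datasets")

loader = LengthlessLoader()
# `if loader:` would invoke __len__ and raise TypeError;
# the identity check never touches __len__:
dataset = getattr(loader, "dataset", None) if loader is not None else None
```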
nextrec/cli.py CHANGED
@@ -152,16 +152,19 @@ def train_model(train_config_path: str) -> None:
      )
      if data_cfg.get("valid_ratio") is not None:
          logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
-     if data_cfg.get("val_path") or data_cfg.get("valid_path"):
+     if data_cfg.get("valid_path"):
          logger.info(
              format_kv(
                  "Validation path",
                  resolve_path(
-                     data_cfg.get("val_path") or data_cfg.get("valid_path"), config_dir
+                     data_cfg.get("valid_path"), config_dir
                  ),
              )
          )

+     # Determine validation dataset path early for streaming split / fitting
+     val_data_path = data_cfg.get("valid_path")
+
      if streaming:
          file_paths, file_type = resolve_file_paths(str(data_path))
          log_kv_lines(
@@ -180,6 +183,34 @@ def train_model(train_config_path: str) -> None:
              raise ValueError(f"Data file is empty: {first_file}") from exc
          df_columns = list(first_chunk.columns)

+         # Decide training/validation file lists before fitting processor, to avoid
+         # leaking validation statistics into preprocessing (scalers/encoders).
+         streaming_train_files = file_paths
+         streaming_valid_ratio = data_cfg.get("valid_ratio")
+         if val_data_path:
+             streaming_valid_files = None
+         elif streaming_valid_ratio is not None:
+             ratio = float(streaming_valid_ratio)
+             if not (0 < ratio < 1):
+                 raise ValueError(
+                     f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
+                 )
+             total_files = len(file_paths)
+             if total_files < 2:
+                 raise ValueError(
+                     "[NextRec CLI Error] Must provide valid_path or increase the number of data files. At least 2 files are required for streaming validation split."
+                 )
+             val_count = max(1, int(round(total_files * ratio)))
+             if val_count >= total_files:
+                 val_count = total_files - 1
+             streaming_valid_files = file_paths[-val_count:]
+             streaming_train_files = file_paths[:-val_count]
+             logger.info(
+                 f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
+             )
+         else:
+             streaming_valid_files = None
+
      else:
          df = read_table(data_path, data_cfg.get("format"))
          logger.info(format_kv("Rows", len(df)))
@@ -215,7 +246,13 @@ def train_model(train_config_path: str) -> None:
      )

      if streaming:
-         processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
+         if file_type is None:
+             raise ValueError("[NextRec CLI Error] Streaming mode requires a valid file_type")
+         processor.fit_from_files(
+             file_paths=streaming_train_files or file_paths,
+             file_type=file_type,
+             chunk_size=dataloader_chunk_size,
+         )
          processed = None
          df = None  # type: ignore[assignment]
      else:
@@ -232,34 +269,6 @@ def train_model(train_config_path: str) -> None:
              sequence_names,
          )

-     # Check if validation dataset path is specified
-     val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
-     if streaming:
-         if not file_paths:
-             file_paths, file_type = resolve_file_paths(str(data_path))
-         streaming_train_files = file_paths
-         streaming_valid_ratio = data_cfg.get("valid_ratio")
-         if val_data_path:
-             streaming_valid_files = None
-         elif streaming_valid_ratio is not None:
-             ratio = float(streaming_valid_ratio)
-             if not (0 < ratio < 1):
-                 raise ValueError(
-                     f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
-                 )
-             total_files = len(file_paths)
-             if total_files < 2:
-                 raise ValueError(
-                     "[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
-                 )
-             val_count = max(1, int(round(total_files * ratio)))
-             if val_count >= total_files:
-                 val_count = total_files - 1
-             streaming_valid_files = file_paths[-val_count:]
-             streaming_train_files = file_paths[:-val_count]
-             logger.info(
-                 f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
-             )
      train_data: Dict[str, Any]
      valid_data: Dict[str, Any] | None

@@ -682,16 +691,22 @@ Examples:
      if not args.mode:
          parser.error("[NextRec CLI Error] --mode is required (train|predict)")

-     if args.mode == "train":
-         config_path = args.train_config
-         if not config_path:
-             parser.error("[NextRec CLI Error] train mode requires --train_config")
-         train_model(config_path)
-     else:
-         config_path = args.predict_config
-         if not config_path:
-             parser.error("[NextRec CLI Error] predict mode requires --predict_config")
-         predict_model(config_path)
+     try:
+         if args.mode == "train":
+             config_path = args.train_config
+             if not config_path:
+                 parser.error("[NextRec CLI Error] train mode requires --train_config")
+             train_model(config_path)
+         else:
+             config_path = args.predict_config
+             if not config_path:
+                 parser.error(
+                     "[NextRec CLI Error] predict mode requires --predict_config"
+                 )
+             predict_model(config_path)
+     except Exception:
+         logging.getLogger(__name__).exception("[NextRec CLI Error] Unhandled exception")
+         raise


  if __name__ == "__main__":
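The CLI entry point now wraps both modes so that any unhandled exception is logged with a full traceback before being re-raised (preserving the non-zero exit). The pattern in isolation:

```python
# The pattern added above: log the full traceback once at the CLI boundary,
# then re-raise so the process exit code still reflects the failure.
import logging

logging.basicConfig(level=logging.INFO)
try:
    raise RuntimeError("boom")
except Exception:
    logging.getLogger(__name__).exception("[NextRec CLI Error] Unhandled exception")
    # real code re-raises here; swallowed in this sketch so it runs cleanly
```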
nextrec/data/batch_utils.py CHANGED
@@ -12,7 +12,7 @@ import torch


  def stack_section(batch: list[dict], section: Literal["features", "labels", "ids"]):
-     """
+     """
      input example:
      batch = [
          {"features": {"f1": tensor1, "f2": tensor2}, "labels": {"label": tensor3}},
@@ -24,7 +24,7 @@ def stack_section(batch: list[dict], section: Literal["features", "labels", "ids"]):
          "f1": torch.stack([tensor1, tensor4], dim=0),
          "f2": torch.stack([tensor2, tensor5], dim=0),
      }
-
+
      """
      entries = [item.get(section) for item in batch if item.get(section) is not None]
      if not entries:
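Both hunks here are whitespace-only docstring cleanups. For reference, a runnable reduction of the behavior the docstring describes, assuming every entry holds the same keys with equal-shaped tensors (the real helper in nextrec/data/batch_utils.py handles more cases):

```python
# Simplified sketch of stack_section's documented behavior.
import torch

def stack_section_sketch(batch: list[dict], section: str) -> dict | None:
    entries = [item.get(section) for item in batch if item.get(section) is not None]
    if not entries:
        return None
    # Stack each feature's per-sample tensors along a new batch dimension.
    return {k: torch.stack([e[k] for e in entries], dim=0) for k in entries[0]}

batch = [
    {"features": {"f1": torch.tensor([1.0]), "f2": torch.tensor([2.0])}},
    {"features": {"f1": torch.tensor([3.0]), "f2": torch.tensor([4.0])}},
]
out = stack_section_sketch(batch, "features")
assert out is not None and out["f1"].shape == (2, 1)
```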
nextrec/data/preprocessor.py CHANGED
@@ -13,7 +13,7 @@ import logging
  import os
  import pickle
  from pathlib import Path
- from typing import Any, Dict, Literal, Optional, Union
+ from typing import Any, Dict, Literal, Optional, Union, overload

  import numpy as np
  import pandas as pd
@@ -566,35 +566,16 @@ class DataProcessor(FeatureSet):
              return [str(v) for v in value]
          return [str(value)]

-     def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
-         """
-         Fit processor statistics by streaming files to reduce memory usage.
-
-         Args:
-             path (str): File or directory path.
-             chunk_size (int): Number of rows per chunk.
-
-         Returns:
-             DataProcessor: Fitted DataProcessor instance.
-         """
+     def fit_from_file_paths(
+         self, file_paths: list[str], file_type: str, chunk_size: int
+     ) -> "DataProcessor":
          logger = logging.getLogger()
-         logger.info(
-             colorize(
-                 "Fitting DataProcessor (streaming path mode)...",
-                 color="cyan",
-                 bold=True,
-             )
-         )
-         for config in self.sparse_features.values():
-             config.pop("_min_freq_logged", None)
-         for config in self.sequence_features.values():
-             config.pop("_min_freq_logged", None)
-         file_paths, file_type = resolve_file_paths(path)
+         if not file_paths:
+             raise ValueError("[DataProcessor Error] Empty file list for streaming fit")
          if not check_streaming_support(file_type):
              raise ValueError(
                  f"[DataProcessor Error] Format '{file_type}' does not support streaming. "
-                 "fit_from_path only supports streaming formats (csv, parquet) to avoid high memory usage. "
-                 "Use fit(dataframe) with in-memory data or convert the data format."
+                 "Streaming fit only supports csv, parquet to avoid high memory usage."
              )

          numeric_acc = {}
@@ -636,6 +617,7 @@ class DataProcessor(FeatureSet):
          target_values: Dict[str, set[Any]] = {
              name: set() for name in self.target_features.keys()
          }
+
          missing_features = set()
          for file_path in file_paths:
              for chunk in iter_file_chunks(file_path, file_type, chunk_size):
@@ -702,10 +684,12 @@ class DataProcessor(FeatureSet):
                  for name in self.target_features.keys() & columns:
                      vals = chunk[name].dropna().tolist()
                      target_values[name].update(vals)
+
          if missing_features:
              logger.warning(
                  f"The following configured features were not found in provided files: {sorted(missing_features)}"
              )
+
          # finalize numeric scalers
          for name, config in self.numeric_features.items():
              acc = numeric_acc[name]
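These hunks stream each file in chunks, update per-feature accumulators (numeric_acc, target_values), and then finalize them into scaler statistics. The accumulator's internals are not shown in this diff; a typical chunk-wise moment accumulator that supports such a finalize step looks like:

```python
# Illustrative only: a streaming mean/std accumulator of the kind numeric_acc
# presumably holds per feature (the real fields are not shown in this diff).
import numpy as np

class RunningMoments:
    def __init__(self):
        self.n, self.s, self.sq = 0, 0.0, 0.0

    def update(self, values: np.ndarray) -> None:
        v = values[~np.isnan(values)]        # ignore missing entries
        self.n += v.size
        self.s += float(v.sum())
        self.sq += float((v ** 2).sum())

    def finalize(self) -> tuple[float, float]:
        mean = self.s / max(self.n, 1)
        var = max(self.sq / max(self.n, 1) - mean ** 2, 0.0)
        return mean, var ** 0.5

acc = RunningMoments()
for chunk in (np.array([1.0, 2.0]), np.array([3.0, np.nan])):
    acc.update(chunk)
mean, std = acc.finalize()  # mean of [1, 2, 3] -> 2.0
```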
@@ -895,6 +879,91 @@ class DataProcessor(FeatureSet):
              )
          return self

+     def fit_from_files(
+         self, file_paths: list[str], file_type: str, chunk_size: int
+     ) -> "DataProcessor":
+         """Fit processor statistics by streaming an explicit list of files.
+
+         This is useful when you want to fit statistics on training files only (exclude
+         validation files) in streaming mode.
+         """
+         logger = logging.getLogger()
+         logger.info(
+             colorize(
+                 "Fitting DataProcessor (streaming files mode)...",
+                 color="cyan",
+                 bold=True,
+             )
+         )
+         for config in self.sparse_features.values():
+             config.pop("_min_freq_logged", None)
+         for config in self.sequence_features.values():
+             config.pop("_min_freq_logged", None)
+         uses_robust = any(
+             cfg.get("scaler") == "robust" for cfg in self.numeric_features.values()
+         )
+         if uses_robust:
+             logger.warning(
+                 "Robust scaler requires full data; loading provided files into memory. "
+                 "Consider smaller chunk_size or different scaler if memory is limited."
+             )
+             frames = [read_table(p, file_type) for p in file_paths]
+             df = pd.concat(frames, ignore_index=True) if len(frames) > 1 else frames[0]
+             return self.fit(df)
+         return self.fit_from_file_paths(file_paths=file_paths, file_type=file_type, chunk_size=chunk_size)
+
+     def fit_from_path(self, path: str, chunk_size: int) -> "DataProcessor":
+         """
+         Fit processor statistics by streaming files to reduce memory usage.
+
+         Args:
+             path (str): File or directory path.
+             chunk_size (int): Number of rows per chunk.
+
+         Returns:
+             DataProcessor: Fitted DataProcessor instance.
+         """
+         logger = logging.getLogger()
+         logger.info(
+             colorize(
+                 "Fitting DataProcessor (streaming path mode)...",
+                 color="cyan",
+                 bold=True,
+             )
+         )
+         for config in self.sparse_features.values():
+             config.pop("_min_freq_logged", None)
+         for config in self.sequence_features.values():
+             config.pop("_min_freq_logged", None)
+         file_paths, file_type = resolve_file_paths(path)
+         return self.fit_from_file_paths(
+             file_paths=file_paths,
+             file_type=file_type,
+             chunk_size=chunk_size,
+         )
+
+     @overload
+     def transform_in_memory(
+         self,
+         data: Union[pd.DataFrame, Dict[str, Any]],
+         return_dict: Literal[True],
+         persist: bool,
+         save_format: Optional[str],
+         output_path: Optional[str],
+         warn_missing: bool = True,
+     ) -> Dict[str, np.ndarray]: ...
+
+     @overload
+     def transform_in_memory(
+         self,
+         data: Union[pd.DataFrame, Dict[str, Any]],
+         return_dict: Literal[False],
+         persist: bool,
+         save_format: Optional[str],
+         output_path: Optional[str],
+         warn_missing: bool = True,
+     ) -> pd.DataFrame: ...
+
      def transform_in_memory(
          self,
          data: Union[pd.DataFrame, Dict[str, Any]],
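The robust-scaler branch in fit_from_files falls back to loading all files into memory because robust scaling is built on medians and quantiles, which cannot be merged from per-chunk statistics the way counts and sums can. A small demonstration:

```python
# Why the robust scaler forces an in-memory fallback above: chunk medians
# cannot be combined into the global median, unlike sums and counts.
import numpy as np

a, b = np.array([1.0, 2.0, 3.0]), np.array([4.0, 100.0])
full = np.concatenate([a, b])
chunk_medians = [np.median(a), np.median(b)]      # [2.0, 52.0]
assert np.median(full) == 3.0                      # true global median
assert np.mean(chunk_medians) != np.median(full)   # naive combination is wrong
```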
@@ -1238,6 +1307,36 @@ class DataProcessor(FeatureSet):
          self.is_fitted = True
          return self

+     @overload
+     def transform(
+         self,
+         data: Union[pd.DataFrame, Dict[str, Any]],
+         return_dict: Literal[True] = True,
+         save_format: Optional[str] = None,
+         output_path: Optional[str] = None,
+         chunk_size: int = 200000,
+     ) -> Dict[str, np.ndarray]: ...
+
+     @overload
+     def transform(
+         self,
+         data: Union[pd.DataFrame, Dict[str, Any]],
+         return_dict: Literal[False] = False,
+         save_format: Optional[str] = None,
+         output_path: Optional[str] = None,
+         chunk_size: int = 200000,
+     ) -> pd.DataFrame: ...
+
+     @overload
+     def transform(
+         self,
+         data: str | os.PathLike,
+         return_dict: Literal[False] = False,
+         save_format: Optional[str] = None,
+         output_path: Optional[str] = None,
+         chunk_size: int = 200000,
+     ) -> list[str]: ...
+
      def transform(
          self,
          data: Union[pd.DataFrame, Dict[str, Any], str, os.PathLike],
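As with predict and transform_in_memory, these overloads let a type checker infer the result from the literal passed for return_dict. Hypothetical call sites (a fitted processor and a DataFrame df are assumed; the list[str] case presumably holds the paths of the chunked output files):

```python
# Assumed objects: `processor` is a fitted DataProcessor, `df` a pd.DataFrame.
# feats = processor.transform(df, return_dict=True)        # Dict[str, np.ndarray]
# frame = processor.transform(df, return_dict=False)       # pd.DataFrame
# parts = processor.transform("data/", return_dict=False)  # list[str] (output files)
```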
nextrec/models/multi_task/[pre]aitm.py CHANGED
@@ -1,6 +1,6 @@
  """
  Date: create on 01/01/2026 - prerelease version: need to overwrite compute_loss later
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 01/14/2026
  Author: Yang Zhou, zyaztec@gmail.com
  Reference:
      - [1] Xi D, Chen Z, Yan P, Zhang Y, Zhu Y, Zhuang F, Chen Y. Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising. Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery & Data Mining (KDD ’21), 2021, pp. 3745–3755.
@@ -20,7 +20,7 @@ from nextrec.basic.layers import MLP, EmbeddingLayer
  from nextrec.basic.heads import TaskHead
  from nextrec.basic.model import BaseModel
  from nextrec.utils.model import get_mlp_output_dim
- from nextrec.utils.types import TaskTypeName
+ from nextrec.utils.types import TaskTypeInput


  class AITMTransfer(nn.Module):
@@ -76,7 +76,7 @@ class AITM(BaseModel):
          tower_mlp_params_list: list[dict] | None = None,
          calibrator_alpha: float = 0.1,
          target: list[str] | str | None = None,
-         task: list[TaskTypeName] | None = None,
+         task: list[TaskTypeInput] | None = None,
          **kwargs,
      ):
          dense_features = dense_features or []
nextrec/models/multi_task/[pre]snr_trans.py CHANGED
@@ -1,6 +1,6 @@
  """
  Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 01/14/2026
  Author: Yang Zhou, zyaztec@gmail.com
  Reference:
      - [1] Ma J, Zhao Z, Chen J, Li A, Hong L, Chi EH. SNR: Sub-Network Routing for Flexible Parameter Sharing in Multi-Task Learning in E-Commerce by Exploiting Task Relationships in the Label Space. Proceedings of the 33rd AAAI Conference on Artificial Intelligence (AAAI 2019), 2019, pp. 216-223.
@@ -22,7 +22,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.basic.layers import EmbeddingLayer, MLP
  from nextrec.basic.heads import TaskHead
  from nextrec.basic.model import BaseModel
- from nextrec.utils.types import TaskTypeName
+ from nextrec.utils.types import TaskTypeInput, TaskTypeName


  class SNRTransGate(nn.Module):
@@ -101,7 +101,7 @@ class SNRTrans(BaseModel):
          num_experts: int = 4,
          tower_mlp_params_list: list[dict] | None = None,
          target: list[str] | str | None = None,
-         task: TaskTypeName | list[TaskTypeName] | None = None,
+         task: TaskTypeInput | list[TaskTypeInput] | None = None,
          **kwargs,
      ) -> None:
          dense_features = dense_features or []
nextrec/models/multi_task/[pre]star.py CHANGED
@@ -1,6 +1,6 @@
  """
  Date: create on 01/01/2026 - prerelease version: still need to align with the source paper
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 01/14/2026
  Author: Yang Zhou, zyaztec@gmail.com
  Reference:
      - [1] Sheng XR, Zhao L, Zhou G, Ding X, Dai B, Luo Q, Yang S, Lv J, Zhang C, Deng H, Zhu X. One Model to Serve All: Star Topology Adaptive Recommender for Multi-Domain CTR Prediction. arXiv preprint arXiv:2101.11427, 2021.
@@ -22,7 +22,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.basic.heads import TaskHead
  from nextrec.basic.layers import DomainBatchNorm, EmbeddingLayer
  from nextrec.basic.model import BaseModel
- from nextrec.utils.types import TaskTypeName
+ from nextrec.utils.types import TaskTypeInput, TaskTypeName


  class SharedSpecificLinear(nn.Module):
@@ -73,7 +73,7 @@ class STAR(BaseModel):
          sparse_features: list[SparseFeature] | None = None,
          sequence_features: list[SequenceFeature] | None = None,
          target: list[str] | str | None = None,
-         task: TaskTypeName | list[TaskTypeName] | None = None,
+         task: TaskTypeInput | list[TaskTypeInput] | None = None,
          mlp_params: dict | None = None,
          use_shared: bool = True,
          **kwargs,
nextrec/models/multi_task/apg.py CHANGED
@@ -1,6 +1,6 @@
  """
  Date: create on 01/01/2026
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 01/14/2026
  Author: Yang Zhou, zyaztec@gmail.com
  Reference:
      - [1] Yan B, Wang P, Zhang K, Li F, Deng H, Xu J, Zheng B. APG: Adaptive Parameter Generation Network for Click-Through Rate Prediction. Advances in Neural Information Processing Systems 35 (NeurIPS 2022), 2022.
@@ -20,7 +20,7 @@ from nextrec.basic.layers import EmbeddingLayer, MLP
  from nextrec.basic.heads import TaskHead
  from nextrec.basic.model import BaseModel
  from nextrec.utils.model import select_features
- from nextrec.utils.types import ActivationName, TaskTypeName
+ from nextrec.utils.types import ActivationName, TaskTypeInput, TaskTypeName


  class APGLayer(nn.Module):
@@ -233,7 +233,7 @@ class APG(BaseModel):
          sparse_features: list[SparseFeature] | None = None,
          sequence_features: list[SequenceFeature] | None = None,
          target: list[str] | str | None = None,
-         task: TaskTypeName | list[TaskTypeName] | None = None,
+         task: TaskTypeInput | list[TaskTypeInput] | None = None,
          mlp_params: dict | None = None,
          inner_activation: ActivationName | None = None,
          generate_activation: ActivationName | None = None,
nextrec/models/multi_task/cross_stitch.py CHANGED
@@ -1,6 +1,6 @@
  """
  Date: create on 01/01/2026
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 01/14/2026
  Author: Yang Zhou, zyaztec@gmail.com
  Reference:
      - [1] Misra I, Shrivastava A, Gupta A, Hebert M. Cross-Stitch Networks for Multi-Task Learning. Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR 2016), 2016, pp. 3994–4003.
@@ -21,7 +21,7 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.basic.layers import EmbeddingLayer, MLP
  from nextrec.basic.heads import TaskHead
  from nextrec.basic.model import BaseModel
- from nextrec.utils.types import TaskTypeName
+ from nextrec.utils.types import TaskTypeInput, TaskTypeName


  class CrossStitchLayer(nn.Module):
@@ -76,7 +76,7 @@ class CrossStitch(BaseModel):
          sparse_features: list[SparseFeature] | None = None,
          sequence_features: list[SequenceFeature] | None = None,
          target: list[str] | str | None = None,
-         task: TaskTypeName | list[TaskTypeName] | None = None,
+         task: TaskTypeInput | list[TaskTypeInput] | None = None,
          shared_mlp_params: dict | None = None,
          task_mlp_params: dict | None = None,
          tower_mlp_params: dict | None = None,
nextrec/models/multi_task/escm.py CHANGED
@@ -1,6 +1,6 @@
  """
  Date: create on 01/01/2026
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 01/14/2026
  Author: Yang Zhou, zyaztec@gmail.com
  Reference:
      - [1] Wang H, Chang T-W, Liu T, Huang J, Chen Z, Yu C, Li R, Chu W. ESCM²: Entire Space Counterfactual Multi-Task Model for Post-Click Conversion Rate Estimation. Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR ’22), 2022:363–372.
@@ -23,7 +23,7 @@ from nextrec.basic.layers import EmbeddingLayer, MLP
  from nextrec.basic.model import BaseModel
  from nextrec.loss.grad_norm import get_grad_norm_shared_params
  from nextrec.utils.model import compute_ranking_loss
- from nextrec.utils.types import TaskTypeName
+ from nextrec.utils.types import TaskTypeInput, TaskTypeName


  class ESCM(BaseModel):
@@ -52,7 +52,7 @@ class ESCM(BaseModel):
          imp_mlp_params: dict | None = None,
          use_dr: bool = False,
          target: list[str] | str | None = None,
-         task: TaskTypeName | list[TaskTypeName] | None = None,
+         task: TaskTypeInput | list[TaskTypeInput] | None = None,
          **kwargs,
      ) -> None:
          dense_features = dense_features or []
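Across these model files (and, per the file list above, the remaining multi-task and ranking models) the pattern is identical: bump the Checkpoint date and widen the task parameter from TaskTypeName to TaskTypeInput. The widening matters for config-driven callers, since a plain str does not satisfy a Literal parameter under static type checking; an isolated illustration:

```python
# Why widening helps (illustrative): a str variable, e.g. read from a YAML
# config, fails static checking against a Literal-only parameter.
from typing import Literal

TaskTypeName = Literal["binary", "regression"]  # assumed values for the demo

def set_task(task: TaskTypeName) -> str:
    return task

task_from_config: str = "binary"
# set_task(task_from_config)  # a type checker rejects str -> Literal here
set_task("binary")            # a literal string is accepted
```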