nextrec-0.3.5-py3-none-any.whl → nextrec-0.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. nextrec/__init__.py +0 -30
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/layers.py +32 -15
  4. nextrec/basic/loggers.py +1 -1
  5. nextrec/basic/model.py +440 -189
  6. nextrec/basic/session.py +4 -2
  7. nextrec/data/__init__.py +0 -25
  8. nextrec/data/data_processing.py +31 -19
  9. nextrec/data/dataloader.py +51 -16
  10. nextrec/models/generative/__init__.py +0 -5
  11. nextrec/models/generative/hstu.py +3 -2
  12. nextrec/models/match/__init__.py +0 -13
  13. nextrec/models/match/dssm.py +0 -1
  14. nextrec/models/match/dssm_v2.py +0 -1
  15. nextrec/models/match/mind.py +0 -1
  16. nextrec/models/match/sdm.py +0 -1
  17. nextrec/models/match/youtube_dnn.py +0 -1
  18. nextrec/models/multi_task/__init__.py +0 -0
  19. nextrec/models/multi_task/esmm.py +5 -7
  20. nextrec/models/multi_task/mmoe.py +10 -6
  21. nextrec/models/multi_task/ple.py +10 -6
  22. nextrec/models/multi_task/poso.py +9 -6
  23. nextrec/models/multi_task/share_bottom.py +10 -7
  24. nextrec/models/ranking/__init__.py +0 -27
  25. nextrec/models/ranking/afm.py +113 -21
  26. nextrec/models/ranking/autoint.py +15 -9
  27. nextrec/models/ranking/dcn.py +8 -11
  28. nextrec/models/ranking/deepfm.py +5 -5
  29. nextrec/models/ranking/dien.py +4 -4
  30. nextrec/models/ranking/din.py +4 -4
  31. nextrec/models/ranking/fibinet.py +4 -4
  32. nextrec/models/ranking/fm.py +4 -4
  33. nextrec/models/ranking/masknet.py +4 -5
  34. nextrec/models/ranking/pnn.py +4 -4
  35. nextrec/models/ranking/widedeep.py +4 -4
  36. nextrec/models/ranking/xdeepfm.py +4 -4
  37. nextrec/utils/__init__.py +7 -3
  38. nextrec/utils/device.py +32 -1
  39. nextrec/utils/distributed.py +114 -0
  40. nextrec/utils/synthetic_data.py +413 -0
  41. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
  42. nextrec-0.4.1.dist-info/RECORD +66 -0
  43. nextrec-0.3.5.dist-info/RECORD +0 -63
  44. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
  45. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py CHANGED
@@ -22,6 +22,7 @@ class Session:
 
     experiment_id: str
     root: Path
+    log_basename: str  # The base name for log files, without path separators
 
     @property
    def logs_dir(self) -> Path:
@@ -60,7 +61,6 @@ class Session:
        return path
 
 def create_session(experiment_id: str | Path | None = None) -> Session:
-    """Create a :class:`Session` instance with prepared directories."""
 
    if experiment_id is not None and str(experiment_id).strip():
        exp_id = str(experiment_id).strip()
@@ -68,6 +68,8 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
        # Use local time for session naming
        exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")
 
+    log_basename = Path(exp_id).name if exp_id else exp_id
+
    if (
        os.getenv("PYTEST_CURRENT_TEST")
        or os.getenv("PYTEST_RUNNING")
@@ -82,7 +84,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
    session_path.mkdir(parents=True, exist_ok=True)
    root = session_path.resolve()
 
-    return Session(experiment_id=exp_id, root=root)
+    return Session(experiment_id=exp_id, root=root, log_basename=log_basename)
 
 def resolve_save_path(
    path: str | os.PathLike | Path | None,
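The new log_basename field is derived with Path(exp_id).name, which keeps only the final component of the experiment id, so log file names cannot carry path separators into the logs directory. A minimal sketch of that behaviour (the example ids are hypothetical; only the Path(...).name expression comes from the hunk above):

from pathlib import Path

for exp_id in ["my_exp", "runs/2025/my_exp", "nextrec_session_20251203"]:
    # Path(...).name keeps only the last path component, stripping any directory parts
    log_basename = Path(exp_id).name if exp_id else exp_id
    print(exp_id, "->", log_basename)
# my_exp -> my_exp
# runs/2025/my_exp -> my_exp
# nextrec_session_20251203 -> nextrec_session_20251203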
nextrec/data/__init__.py CHANGED
@@ -1,22 +1,4 @@
-"""
-Data utilities package for NextRec
-
-This package provides data processing and manipulation utilities organized by category:
-- batch_utils: Batch collation and processing
-- data_processing: Data manipulation and user ID extraction
-- data_utils: Legacy module (re-exports from specialized modules)
-- dataloader: Dataset and DataLoader implementations
-- preprocessor: Data preprocessing pipeline
-
-Date: create on 13/11/2025
-Last update: 03/12/2025 (refactored)
-Author: Yang Zhou, zyaztec@gmail.com
-"""
-
-# Batch utilities
 from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
-
-# Data processing utilities
 from nextrec.data.data_processing import (
     get_column_data,
     split_dict_random,
@@ -24,7 +6,6 @@ from nextrec.data.data_processing import (
     get_user_ids,
 )
 
-# File utilities (from utils package)
 from nextrec.utils.file import (
     resolve_file_paths,
     iter_file_chunks,
@@ -33,7 +14,6 @@ from nextrec.utils.file import (
     default_output_dir,
 )
 
-# DataLoader components
 from nextrec.data.dataloader import (
     TensorDictDataset,
     FileDataset,
@@ -41,13 +21,8 @@ from nextrec.data.dataloader import (
     build_tensors_from_data,
 )
 
-# Preprocessor
 from nextrec.data.preprocessor import DataProcessor
-
-# Feature definitions
 from nextrec.basic.features import FeatureSet
-
-# Legacy module (for backward compatibility)
 from nextrec.data import data_utils
 
 __all__ = [
nextrec/data/data_processing.py CHANGED
@@ -11,7 +11,10 @@ import pandas as pd
 from typing import Any, Mapping
 
 
-def get_column_data(data: dict | pd.DataFrame, name: str):
+def get_column_data(
+    data: dict | pd.DataFrame,
+    name: str):
+
     if isinstance(data, dict):
         return data[name] if name in data else None
     elif isinstance(data, pd.DataFrame):
@@ -24,10 +27,11 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
     raise KeyError(f"Unsupported data type for extracting column {name}")
 
 def split_dict_random(
-        data_dict: dict,
-        test_size: float = 0.2,
-        random_state: int | None = None
-        ):
+    data_dict: dict,
+    test_size: float = 0.2,
+    random_state: int | None = None
+    ):
+
     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
         raise ValueError(f"Length mismatch: {lengths}")
@@ -51,18 +55,27 @@
     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
     return train_dict, test_dict
 
+def split_data(
+    df: pd.DataFrame,
+    test_size: float = 0.2
+    ) -> tuple[pd.DataFrame, pd.DataFrame]:
+
+    split_idx = int(len(df) * (1 - test_size))
+    train_df = df.iloc[:split_idx].reset_index(drop=True)
+    valid_df = df.iloc[split_idx:].reset_index(drop=True)
+    return train_df, valid_df
 
 def build_eval_candidates(
-        df_all: pd.DataFrame,
-        user_col: str,
-        item_col: str,
-        label_col: str,
-        user_features: pd.DataFrame,
-        item_features: pd.DataFrame,
-        num_pos_per_user: int = 5,
-        num_neg_per_pos: int = 50,
-        random_seed: int = 2025,
-        ) -> pd.DataFrame:
+    df_all: pd.DataFrame,
+    user_col: str,
+    item_col: str,
+    label_col: str,
+    user_features: pd.DataFrame,
+    item_features: pd.DataFrame,
+    num_pos_per_user: int = 5,
+    num_neg_per_pos: int = 50,
+    random_seed: int = 2025,
+    ) -> pd.DataFrame:
     """
     Build evaluation candidates with positive and negative samples for each user.
 
@@ -111,11 +124,10 @@ def build_eval_candidates(
     eval_df = eval_df.merge(item_features, on=item_col, how='left')
     return eval_df
 
-
 def get_user_ids(
-        data: Any,
-        id_columns: list[str] | str | None = None
-        ) -> np.ndarray | None:
+    data: Any,
+    id_columns: list[str] | str | None = None
+    ) -> np.ndarray | None:
     """
     Extract user IDs from various data structures.
 
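The newly added split_data helper performs a sequential head/tail split rather than a random one. A small usage sketch based only on the function body shown above (the toy DataFrame is illustrative, and the import assumes the function stays in the module where it is defined):

import pandas as pd
from nextrec.data.data_processing import split_data

df = pd.DataFrame({"user_id": range(10), "label": [0, 1] * 5})
train_df, valid_df = split_data(df, test_size=0.2)
# The first 80% of rows become train, the last 20% become valid, both reindexed from 0.
print(len(train_df), len(valid_df))  # 8 2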
nextrec/data/dataloader.py CHANGED
@@ -15,15 +15,15 @@ import pyarrow.parquet as pq
 from pathlib import Path
 from typing import cast
 
-from torch.utils.data import DataLoader, Dataset, IterableDataset
-from nextrec.data.preprocessor import DataProcessor
+from nextrec.basic.loggers import colorize
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
+from nextrec.data.preprocessor import DataProcessor
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 
-from nextrec.basic.loggers import colorize
-from nextrec.data.data_processing import get_column_data
-from nextrec.data.batch_utils import collate_fn
-from nextrec.utils.file import resolve_file_paths, read_table
 from nextrec.utils.tensor import to_tensor
+from nextrec.utils.file import resolve_file_paths, read_table
+from nextrec.data.batch_utils import collate_fn
+from nextrec.data.data_processing import get_column_data
 
 class TensorDictDataset(Dataset):
     """Dataset returning sample-level dicts matching the unified batch schema."""
@@ -118,6 +118,18 @@ class RecDataLoader(FeatureSet):
                  target: list[str] | None | str = None,
                  id_columns: str | list[str] | None = None,
                  processor: DataProcessor | None = None):
+        """
+        RecDataLoader is a unified dataloader for supporting in-memory and streaming data.
+        Basemodel will accept RecDataLoader to create dataloaders for training/evaluation/prediction.
+
+        Args:
+            dense_features: list of DenseFeature definitions
+            sparse_features: list of SparseFeature definitions
+            sequence_features: list of SequenceFeature definitions
+            target: target column name(s), e.g. 'label' or ['ctr', 'ctcvr']
+            id_columns: id column name(s) to carry through (not used for model inputs), e.g. 'user_id' or ['user_id', 'item_id']
+            processor: an instance of DataProcessor, if provided, will be used to transform data before creating tensors.
+        """
         self.processor = processor
         self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
@@ -126,20 +138,40 @@ class RecDataLoader(FeatureSet):
                           batch_size: int = 32,
                           shuffle: bool = True,
                           load_full: bool = True,
-                          chunk_size: int = 10000) -> DataLoader:
+                          chunk_size: int = 10000,
+                          num_workers: int = 0,
+                          sampler = None) -> DataLoader:
+        """
+        Create a DataLoader from various data sources.
+
+        Args:
+            data: Data source, can be a dict, pd.DataFrame, file path (str), or existing DataLoader.
+            batch_size: Batch size for DataLoader.
+            shuffle: Whether to shuffle the data (ignored in streaming mode).
+            load_full: If True, load full data into memory; if False, use streaming mode for large files.
+            chunk_size: Chunk size for streaming mode (number of rows per chunk).
+            num_workers: Number of worker processes for data loading.
+            sampler: Optional sampler for DataLoader, only used for distributed training.
+        Returns:
+            DataLoader instance.
+        """
+
         if isinstance(data, DataLoader):
             return data
         elif isinstance(data, (str, os.PathLike)):
-            return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size)
+            return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size, num_workers=num_workers)
         elif isinstance(data, (dict, pd.DataFrame)):
-            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle)
+            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, sampler=sampler)
         else:
             raise ValueError(f"[RecDataLoader Error] Unsupported data type: {type(data)}")
 
     def create_from_memory(self,
                            data: dict | pd.DataFrame,
                            batch_size: int,
-                           shuffle: bool) -> DataLoader:
+                           shuffle: bool,
+                           num_workers: int = 0,
+                           sampler=None) -> DataLoader:
+
         raw_data = data
 
         if self.processor is not None:
@@ -150,14 +182,15 @@ class RecDataLoader(FeatureSet):
         if tensors is None:
             raise ValueError("[RecDataLoader Error] No valid tensors could be built from the provided data.")
         dataset = TensorDictDataset(tensors)
-        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
+        return DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
 
     def create_from_path(self,
                          path: str,
                          batch_size: int,
                          shuffle: bool,
                          load_full: bool,
-                         chunk_size: int = 10000) -> DataLoader:
+                         chunk_size: int = 10000,
+                         num_workers: int = 0) -> DataLoader:
         file_paths, file_type = resolve_file_paths(str(Path(path)))
         # Load full data into memory
         if load_full:
@@ -169,6 +202,7 @@ class RecDataLoader(FeatureSet):
                 except OSError:
                     pass
                 try:
+                    df = read_table(file_path, file_type=file_type)
                     dfs.append(df)
                 except MemoryError as exc:
                     raise MemoryError(f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using load_full=False with streaming.") from exc
@@ -176,22 +210,23 @@ class RecDataLoader(FeatureSet):
                 combined_df = pd.concat(dfs, ignore_index=True)
             except MemoryError as exc:
                 raise MemoryError(f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use load_full=False to stream or reduce chunk_size.") from exc
-            return self.create_from_memory(combined_df, batch_size, shuffle,)
+            return self.create_from_memory(combined_df, batch_size, shuffle, num_workers=num_workers)
         else:
-            return self.load_files_streaming(file_paths, file_type, batch_size, chunk_size, shuffle)
+            return self.load_files_streaming(file_paths, file_type, batch_size, chunk_size, shuffle, num_workers=num_workers)
 
     def load_files_streaming(self,
                              file_paths: list[str],
                              file_type: str,
                              batch_size: int,
                              chunk_size: int,
-                             shuffle: bool) -> DataLoader:
+                             shuffle: bool,
+                             num_workers: int = 0) -> DataLoader:
         if shuffle:
             logging.info("[RecDataLoader Info] Shuffle is ignored in streaming mode (IterableDataset).")
         if batch_size != 1:
             logging.info("[RecDataLoader Info] Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
         dataset = FileDataset(file_paths=file_paths, dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target_columns=self.target_columns, id_columns=self.id_columns, chunk_size=chunk_size, file_type=file_type, processor=self.processor)
-        return DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
+        return DataLoader(dataset, batch_size=1, collate_fn=collate_fn, num_workers=num_workers)
 
 def normalize_sequence_column(column, feature: SequenceFeature) -> np.ndarray:
     if isinstance(column, pd.Series):
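For reference, a hedged sketch of how the new num_workers and sampler arguments might be used with in-memory data. The DenseFeature/SparseFeature constructor arguments are illustrative assumptions; only the RecDataLoader and create_from_memory signatures, and the fact that a supplied sampler overrides shuffle, come from the hunks above:

import pandas as pd
from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.data.dataloader import RecDataLoader

# Hypothetical feature definitions for a toy ranking dataset
loader_factory = RecDataLoader(
    dense_features=[DenseFeature("price")],
    sparse_features=[SparseFeature("item_id", vocab_size=1000, embedding_dim=16)],
    target="label",
)

df = pd.DataFrame({"price": [0.1, 0.5, 0.9], "item_id": [3, 7, 11], "label": [0, 1, 0]})

# num_workers is forwarded to torch's DataLoader; when a sampler is given,
# shuffle is forced to False and the sampler decides the iteration order
# (this is the hook a DistributedSampler would use for multi-GPU training).
train_loader = loader_factory.create_from_memory(df, batch_size=2, shuffle=True, num_workers=0)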
nextrec/models/generative/__init__.py CHANGED
@@ -1,5 +0,0 @@
-from .hstu import HSTU
-
-__all__ = [
-    "HSTU",
-]
nextrec/models/generative/hstu.py CHANGED
@@ -255,7 +255,7 @@ class HSTU(BaseModel):
         return "HSTU"
 
     @property
-    def task_type(self) -> str:
+    def default_task(self) -> str:
         return "multiclass"
 
     def __init__(
@@ -275,6 +275,7 @@ class HSTU(BaseModel):
 
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
+        task: str | list[str] | None = None,
         optimizer: str = "adam",
         optimizer_params: Optional[dict] = None,
         scheduler: Optional[str] = None,
@@ -307,7 +308,7 @@ class HSTU(BaseModel):
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
-            task=self.task_type,
+            task=task or self.default_task,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
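Across the models in this release, the hard-coded task_type property becomes default_task, and the constructor falls back to it only when no task is passed. A standalone illustration of that fallback (plain Python, not the NextRec API itself):

def resolve_task(task, default_task):
    # Mirrors the `task=task or self.default_task` pattern used throughout this release
    return task or default_task

print(resolve_task(None, "multiclass"))          # multiclass (HSTU's default)
print(resolve_task("binary", "multiclass"))      # binary (explicit override)
print(resolve_task(None, ["binary", "binary"]))  # ['binary', 'binary'] (multi-task default)
# Note: because `or` tests truthiness, an empty list or empty string also falls back to the default.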
nextrec/models/match/__init__.py CHANGED
@@ -1,13 +0,0 @@
-from .dssm import DSSM
-from .dssm_v2 import DSSM_v2
-from .youtube_dnn import YoutubeDNN
-from .mind import MIND
-from .sdm import SDM
-
-__all__ = [
-    'DSSM',
-    'DSSM_v2',
-    'YoutubeDNN',
-    'MIND',
-    'SDM',
-]
nextrec/models/match/dssm.py CHANGED
@@ -73,7 +73,6 @@ class DSSM(BaseMatchModel):
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
            **kwargs
        )
 
nextrec/models/match/dssm_v2.py CHANGED
@@ -68,7 +68,6 @@ class DSSM_v2(BaseMatchModel):
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
            **kwargs
        )
 
nextrec/models/match/mind.py CHANGED
@@ -184,7 +184,6 @@ class MIND(BaseMatchModel):
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
            **kwargs
        )
 
nextrec/models/match/sdm.py CHANGED
@@ -76,7 +76,6 @@ class SDM(BaseMatchModel):
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
            **kwargs
        )
 
nextrec/models/match/youtube_dnn.py CHANGED
@@ -73,7 +73,6 @@ class YoutubeDNN(BaseMatchModel):
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
            **kwargs
        )
 
nextrec/models/multi_task/__init__.py CHANGED
File without changes
nextrec/models/multi_task/esmm.py CHANGED
@@ -64,10 +64,9 @@ class ESMM(BaseModel):
     @property
     def model_name(self):
         return "ESMM"
-
+
     @property
-    def task_type(self):
-        # ESMM has fixed task types: CTR (binary) and CVR (binary)
+    def default_task(self):
         return ['binary', 'binary']
 
     def __init__(self,
@@ -77,7 +76,7 @@ class ESMM(BaseModel):
                  ctr_params: dict,
                  cvr_params: dict,
                  target: list[str] = ['ctr', 'ctcvr'],  # Note: ctcvr = ctr * cvr
-                 task: list[str] = ['binary', 'binary'],
+                 task: list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -98,13 +97,12 @@ class ESMM(BaseModel):
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
-            task=task,  # Both CTR and CTCVR are binary classification
+            task=task or self.default_task,  # Both CTR and CTCVR are binary classification
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
            **kwargs
        )
 
@@ -126,7 +124,7 @@ class ESMM(BaseModel):
 
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1, 1])
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/multi_task/mmoe.py CHANGED
@@ -65,8 +65,11 @@ class MMOE(BaseModel):
         return "MMOE"
 
     @property
-    def task_type(self):
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature]=[],
@@ -76,7 +79,7 @@ class MMOE(BaseModel):
                  num_experts: int=3,
                  tower_params_list: list[dict]=[],
                  target: list[str]=[],
-                 task: str | list[str] = 'binary',
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -88,18 +91,19 @@ class MMOE(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(MMOE, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
-            task=task,
+            task=task or self.default_task,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
            **kwargs
        )
 
@@ -144,7 +148,7 @@ class MMOE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
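In the multi-task models (MMOE, PLE, POSO, ShareBottom) default_task is now derived from the number of targets, so task no longer has to be spelled out when every head is binary. A standalone sketch mirroring the property body above:

def default_task(num_tasks):
    # One 'binary' head per target; fall back to a single head when num_tasks is unknown or zero
    if num_tasks is not None and num_tasks > 0:
        return ['binary'] * num_tasks
    return ['binary']

print(default_task(2))     # ['binary', 'binary'], e.g. target=['ctr', 'cvr']
print(default_task(None))  # ['binary']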
nextrec/models/multi_task/ple.py CHANGED
@@ -159,8 +159,11 @@ class PLE(BaseModel):
         return "PLE"
 
     @property
-    def task_type(self):
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature],
@@ -173,7 +176,7 @@ class PLE(BaseModel):
                  num_levels: int,
                  tower_params_list: list[dict],
                  target: list[str],
-                 task: str | list[str] = 'binary',
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict | None = None,
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -185,18 +188,19 @@ class PLE(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(PLE, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
-            task=task,
+            task=task or self.default_task,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
            **kwargs
        )
 
@@ -247,7 +251,7 @@ class PLE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
nextrec/models/multi_task/poso.py CHANGED
@@ -261,8 +261,11 @@ class POSO(BaseModel):
         return "POSO"
 
     @property
-    def task_type(self) -> list[str]:
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self) -> list[str]:
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]
 
     def __init__(
         self,
@@ -274,7 +277,7 @@ class POSO(BaseModel):
         pc_sequence_features: list[SequenceFeature] | None,
         tower_params_list: list[dict],
         target: list[str],
-        task: str | list[str] = "binary",
+        task: str | list[str] | None = None,
         architecture: str = "mlp",
         # POSO gating defaults
         gate_hidden_dim: int = 32,
@@ -307,6 +310,7 @@ class POSO(BaseModel):
         self.pc_dense_features = list(pc_dense_features or [])
         self.pc_sparse_features = list(pc_sparse_features or [])
         self.pc_sequence_features = list(pc_sequence_features or [])
+        self.num_tasks = len(target)
 
         if not self.pc_dense_features and not self.pc_sparse_features and not self.pc_sequence_features:
             raise ValueError("POSO requires at least one PC feature for personalization.")
@@ -320,13 +324,12 @@ class POSO(BaseModel):
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
-            task=task,
+            task=task or self.default_task,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
            **kwargs,
        )
 
@@ -387,7 +390,7 @@ class POSO(BaseModel):
         )
         self.towers = nn.ModuleList([MLP(input_dim=self.mmoe.expert_output_dim, output_layer=True, **tower_params,) for tower_params in tower_params_list])
         self.tower_heads = None
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks,)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks,)
         include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
         self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/multi_task/share_bottom.py CHANGED
@@ -53,9 +53,11 @@ class ShareBottom(BaseModel):
         return "ShareBottom"
 
     @property
-    def task_type(self):
-        # Multi-task model, return list of task types
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature],
@@ -64,7 +66,7 @@ class ShareBottom(BaseModel):
                  bottom_params: dict,
                  tower_params_list: list[dict],
                  target: list[str],
-                 task: str | list[str] = 'binary',
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -76,18 +78,19 @@ class ShareBottom(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(ShareBottom, self).__init__(
            dense_features=dense_features,
            sparse_features=sparse_features,
            sequence_features=sequence_features,
            target=target,
-            task=task,
+            task=task or self.default_task,
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
            **kwargs
        )
 
@@ -120,7 +123,7 @@ class ShareBottom(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)