nextrec 0.3.6__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/layers.py +32 -15
  3. nextrec/basic/model.py +435 -187
  4. nextrec/data/data_processing.py +31 -19
  5. nextrec/data/dataloader.py +40 -10
  6. nextrec/models/generative/hstu.py +3 -2
  7. nextrec/models/match/dssm.py +0 -1
  8. nextrec/models/match/dssm_v2.py +0 -1
  9. nextrec/models/match/mind.py +0 -1
  10. nextrec/models/match/sdm.py +0 -1
  11. nextrec/models/match/youtube_dnn.py +0 -1
  12. nextrec/models/multi_task/esmm.py +5 -7
  13. nextrec/models/multi_task/mmoe.py +10 -6
  14. nextrec/models/multi_task/ple.py +10 -6
  15. nextrec/models/multi_task/poso.py +9 -6
  16. nextrec/models/multi_task/share_bottom.py +10 -7
  17. nextrec/models/ranking/afm.py +113 -21
  18. nextrec/models/ranking/autoint.py +15 -9
  19. nextrec/models/ranking/dcn.py +8 -11
  20. nextrec/models/ranking/deepfm.py +5 -5
  21. nextrec/models/ranking/dien.py +4 -4
  22. nextrec/models/ranking/din.py +4 -4
  23. nextrec/models/ranking/fibinet.py +4 -4
  24. nextrec/models/ranking/fm.py +4 -4
  25. nextrec/models/ranking/masknet.py +4 -5
  26. nextrec/models/ranking/pnn.py +4 -4
  27. nextrec/models/ranking/widedeep.py +4 -4
  28. nextrec/models/ranking/xdeepfm.py +4 -4
  29. nextrec/utils/__init__.py +7 -3
  30. nextrec/utils/device.py +30 -0
  31. nextrec/utils/distributed.py +114 -0
  32. nextrec/utils/synthetic_data.py +413 -0
  33. {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
  34. nextrec-0.4.1.dist-info/RECORD +66 -0
  35. nextrec-0.3.6.dist-info/RECORD +0 -64
  36. {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
  37. {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/data/data_processing.py
@@ -11,7 +11,10 @@ import pandas as pd
 from typing import Any, Mapping
 
 
-def get_column_data(data: dict | pd.DataFrame, name: str):
+def get_column_data(
+    data: dict | pd.DataFrame,
+    name: str):
+
     if isinstance(data, dict):
         return data[name] if name in data else None
     elif isinstance(data, pd.DataFrame):
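For orientation, a minimal usage sketch of the reflowed helper above. The dict branch (visible in the hunk) returns None for a missing key; the DataFrame branch is truncated by the hunk, so its return value is an assumption here:

```python
# Hedged sketch of get_column_data: dict branch per the hunk; the DataFrame
# branch is cut off in the diff, so its exact return is assumed.
import pandas as pd
from nextrec.data.data_processing import get_column_data

data = {"user_id": [1, 2, 3], "label": [0, 1, 0]}
assert get_column_data(data, "label") == [0, 1, 0]
assert get_column_data(data, "missing") is None
print(get_column_data(pd.DataFrame(data), "label"))  # presumably the column
```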
@@ -24,10 +27,11 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
     raise KeyError(f"Unsupported data type for extracting column {name}")
 
 def split_dict_random(
-    data_dict: dict,
-    test_size: float = 0.2,
-    random_state: int | None = None
-):
+    data_dict: dict,
+    test_size: float = 0.2,
+    random_state: int | None = None
+):
+
     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
         raise ValueError(f"Length mismatch: {lengths}")
@@ -51,18 +55,27 @@ def split_dict_random(
     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
     return train_dict, test_dict
 
+def split_data(
+    df: pd.DataFrame,
+    test_size: float = 0.2
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+
+    split_idx = int(len(df) * (1 - test_size))
+    train_df = df.iloc[:split_idx].reset_index(drop=True)
+    valid_df = df.iloc[split_idx:].reset_index(drop=True)
+    return train_df, valid_df
 
 def build_eval_candidates(
-    df_all: pd.DataFrame,
-    user_col: str,
-    item_col: str,
-    label_col: str,
-    user_features: pd.DataFrame,
-    item_features: pd.DataFrame,
-    num_pos_per_user: int = 5,
-    num_neg_per_pos: int = 50,
-    random_seed: int = 2025,
-) -> pd.DataFrame:
+    df_all: pd.DataFrame,
+    user_col: str,
+    item_col: str,
+    label_col: str,
+    user_features: pd.DataFrame,
+    item_features: pd.DataFrame,
+    num_pos_per_user: int = 5,
+    num_neg_per_pos: int = 50,
+    random_seed: int = 2025,
+) -> pd.DataFrame:
     """
     Build evaluation candidates with positive and negative samples for each user.
 
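The new `split_data` helper complements `split_dict_random`: it performs a deterministic head/tail split with no shuffling, which suits time-ordered logs. A minimal sketch (the toy frame is illustrative):

```python
# Deterministic split per the hunk above: the first (1 - test_size) of the
# rows become train, the tail becomes validation; both get fresh indices.
import pandas as pd
from nextrec.data.data_processing import split_data, split_dict_random

df = pd.DataFrame({"user_id": range(10), "label": [0, 1] * 5})
train_df, valid_df = split_data(df, test_size=0.2)
assert len(train_df) == 8 and len(valid_df) == 2  # int(10 * 0.8) == 8

# Contrast: split_dict_random shuffles, seeded by random_state.
train_d, test_d = split_dict_random(
    {"user_id": list(range(10)), "label": [0, 1] * 5},
    test_size=0.2,
    random_state=42,
)
```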
@@ -111,11 +124,10 @@ def build_eval_candidates(
     eval_df = eval_df.merge(item_features, on=item_col, how='left')
     return eval_df
 
-
 def get_user_ids(
-    data: Any,
-    id_columns: list[str] | str | None = None
-) -> np.ndarray | None:
+    data: Any,
+    id_columns: list[str] | str | None = None
+) -> np.ndarray | None:
     """
     Extract user IDs from various data structures.
 
nextrec/data/dataloader.py
@@ -15,15 +15,15 @@ import pyarrow.parquet as pq
 from pathlib import Path
 from typing import cast
 
-from torch.utils.data import DataLoader, Dataset, IterableDataset
-from nextrec.data.preprocessor import DataProcessor
+from nextrec.basic.loggers import colorize
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
+from nextrec.data.preprocessor import DataProcessor
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 
-from nextrec.basic.loggers import colorize
-from nextrec.data.data_processing import get_column_data
-from nextrec.data.batch_utils import collate_fn
-from nextrec.utils.file import resolve_file_paths, read_table
 from nextrec.utils.tensor import to_tensor
+from nextrec.utils.file import resolve_file_paths, read_table
+from nextrec.data.batch_utils import collate_fn
+from nextrec.data.data_processing import get_column_data
 
 class TensorDictDataset(Dataset):
     """Dataset returning sample-level dicts matching the unified batch schema."""
@@ -118,6 +118,18 @@ class RecDataLoader(FeatureSet):
                  target: list[str] | None | str = None,
                  id_columns: str | list[str] | None = None,
                  processor: DataProcessor | None = None):
+        """
+        RecDataLoader is a unified dataloader supporting both in-memory and streaming data.
+        BaseModel accepts a RecDataLoader to create dataloaders for training/evaluation/prediction.
+
+        Args:
+            dense_features: list of DenseFeature definitions
+            sparse_features: list of SparseFeature definitions
+            sequence_features: list of SequenceFeature definitions
+            target: target column name(s), e.g. 'label' or ['ctr', 'ctcvr']
+            id_columns: id column name(s) to carry through (not used as model inputs), e.g. 'user_id' or ['user_id', 'item_id']
+            processor: a DataProcessor instance; if provided, it transforms the data before tensors are built.
+        """
         self.processor = processor
         self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
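A sketch of constructing the newly documented RecDataLoader. The feature constructor arguments (names, vocabulary size, embedding dim) are illustrative assumptions, not taken from this diff:

```python
# Hypothetical construction; DenseFeature/SparseFeature argument names here
# are assumptions for illustration only.
from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.data.dataloader import RecDataLoader

loader = RecDataLoader(
    dense_features=[DenseFeature("age")],
    sparse_features=[SparseFeature("item_id", vocab_size=10_000, embed_dim=16)],
    target="label",                      # single target column
    id_columns=["user_id", "item_id"],   # carried through, not model inputs
)
```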
@@ -127,13 +139,29 @@ class RecDataLoader(FeatureSet):
                           shuffle: bool = True,
                           load_full: bool = True,
                           chunk_size: int = 10000,
-                          num_workers: int = 0) -> DataLoader:
+                          num_workers: int = 0,
+                          sampler = None) -> DataLoader:
+        """
+        Create a DataLoader from various data sources.
+
+        Args:
+            data: Data source; can be a dict, pd.DataFrame, file path (str), or an existing DataLoader.
+            batch_size: Batch size for the DataLoader.
+            shuffle: Whether to shuffle the data (ignored in streaming mode).
+            load_full: If True, load the full data into memory; if False, use streaming mode for large files.
+            chunk_size: Chunk size for streaming mode (number of rows per chunk).
+            num_workers: Number of worker processes for data loading.
+            sampler: Optional sampler for the DataLoader, only used for distributed training.
+        Returns:
+            DataLoader instance.
+        """
+
         if isinstance(data, DataLoader):
             return data
         elif isinstance(data, (str, os.PathLike)):
             return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size, num_workers=num_workers)
         elif isinstance(data, (dict, pd.DataFrame)):
-            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, sampler=sampler)
         else:
             raise ValueError(f"[RecDataLoader Error] Unsupported data type: {type(data)}")
@@ -141,7 +169,9 @@ class RecDataLoader(FeatureSet):
                            data: dict | pd.DataFrame,
                            batch_size: int,
                            shuffle: bool,
-                           num_workers: int = 0) -> DataLoader:
+                           num_workers: int = 0,
+                           sampler=None) -> DataLoader:
+
         raw_data = data
 
         if self.processor is not None:
@@ -152,7 +182,7 @@ class RecDataLoader(FeatureSet):
         if tensors is None:
             raise ValueError("[RecDataLoader Error] No valid tensors could be built from the provided data.")
         dataset = TensorDictDataset(tensors)
-        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers)
+        return DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
 
     def create_from_path(self,
                          path: str,
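The `sampler` hook exists because PyTorch's DataLoader rejects `shuffle=True` together with a custom sampler; `create_from_memory` therefore forces `shuffle=False` whenever one is supplied, letting a DistributedSampler shard and shuffle per rank instead. A standalone sketch of that exact guard (plain TensorDataset, not nextrec code):

```python
# Mirrors the shuffle/sampler guard from the diff, outside nextrec.
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))
sampler = None
# Under an initialized process group (e.g. launched via torchrun), one would
# build: sampler = DistributedSampler(dataset, shuffle=True)

dl = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False if sampler is not None else True,  # the guard from the diff
    sampler=sampler,
)
```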
nextrec/models/generative/hstu.py
@@ -255,7 +255,7 @@ class HSTU(BaseModel):
         return "HSTU"
 
     @property
-    def task_type(self) -> str:
+    def default_task(self) -> str:
         return "multiclass"
 
     def __init__(
@@ -275,6 +275,7 @@ class HSTU(BaseModel):
 
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
+        task: str | list[str] | None = None,
        optimizer: str = "adam",
         optimizer_params: Optional[dict] = None,
         scheduler: Optional[str] = None,
@@ -307,7 +308,7 @@ class HSTU(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.task_type,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
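The rename from `task_type` to `default_task` recurs throughout this release: the property now only supplies a fallback, and the constructor's new optional `task` argument wins when provided. A self-contained sketch of the pattern:

```python
# The `task or self.default_task` idiom used across 0.4.1: any falsy task
# (None, '', []) falls back to the model's declared default.
class SketchModel:
    @property
    def default_task(self) -> str:
        return "multiclass"  # HSTU's default, per the hunk above

    def __init__(self, task: str | None = None):
        self.task = task or self.default_task

assert SketchModel().task == "multiclass"
assert SketchModel(task="binary").task == "binary"
```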
nextrec/models/match/dssm.py
@@ -73,7 +73,6 @@ class DSSM(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
             **kwargs
         )
 
nextrec/models/match/dssm_v2.py
@@ -68,7 +68,6 @@ class DSSM_v2(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
             **kwargs
         )
 
nextrec/models/match/mind.py
@@ -184,7 +184,6 @@ class MIND(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
             **kwargs
         )
 
nextrec/models/match/sdm.py
@@ -76,7 +76,6 @@ class SDM(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
             **kwargs
         )
 
nextrec/models/match/youtube_dnn.py
@@ -73,7 +73,6 @@ class YoutubeDNN(BaseMatchModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
             **kwargs
         )
 
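All five match models (DSSM, DSSM_v2, MIND, SDM, YoutubeDNN) drop the explicit `early_stop_patience` forwarding. This diff does not show where the setting moved; one plausible reading, since the `**kwargs` passthrough survives, is that it can still reach the base class as a keyword argument. A sketch of that assumption only:

```python
# Assumption, not confirmed by this diff: early_stop_patience riding through
# **kwargs instead of an explicit keyword in every subclass __init__.
class BaseSketch:
    def __init__(self, early_stop_patience: int = 20, **kwargs):
        self.early_stop_patience = early_stop_patience

class MatchSketch(BaseSketch):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # the patience rides along if given

assert MatchSketch(early_stop_patience=5).early_stop_patience == 5
assert MatchSketch().early_stop_patience == 20
```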
nextrec/models/multi_task/esmm.py
@@ -64,10 +64,9 @@ class ESMM(BaseModel):
     @property
     def model_name(self):
         return "ESMM"
-
+
     @property
-    def task_type(self):
-        # ESMM has fixed task types: CTR (binary) and CVR (binary)
+    def default_task(self):
         return ['binary', 'binary']
 
     def __init__(self,
@@ -77,7 +76,7 @@ class ESMM(BaseModel):
                  ctr_params: dict,
                  cvr_params: dict,
                  target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
-                 task: list[str] = ['binary', 'binary'],
+                 task: list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -98,13 +97,12 @@ class ESMM(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task, # Both CTR and CTCVR are binary classification
+            task=task or self.default_task, # Both CTR and CTCVR are binary classification
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
@@ -126,7 +124,7 @@ class ESMM(BaseModel):
 
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1, 1])
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1, 1])
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/multi_task/mmoe.py
@@ -65,8 +65,11 @@ class MMOE(BaseModel):
         return "MMOE"
 
     @property
-    def task_type(self):
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature]=[],
@@ -76,7 +79,7 @@ class MMOE(BaseModel):
                  num_experts: int=3,
                  tower_params_list: list[dict]=[],
                  target: list[str]=[],
-                 task: str | list[str] = 'binary',
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -88,18 +91,19 @@ class MMOE(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(MMOE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
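For the multi-task models (MMOE here, and likewise PLE, POSO, and ShareBottom below), `num_tasks` is set from `len(target)` before `super().__init__` runs, so `default_task` can emit one 'binary' entry per target when no explicit task list is passed. The pattern in isolation:

```python
# Self-contained sketch of the multi-task default: one 'binary' per target.
class MultiTaskSketch:
    @property
    def default_task(self) -> list[str]:
        num_tasks = getattr(self, "num_tasks", None)
        if num_tasks is not None and num_tasks > 0:
            return ["binary"] * num_tasks
        return ["binary"]

    def __init__(self, target: list[str], task: list[str] | None = None):
        self.num_tasks = len(target)          # set before the default is read
        self.task = task or self.default_task

assert MultiTaskSketch(["ctr", "cvr"]).task == ["binary", "binary"]
assert MultiTaskSketch(["ctr"], task=["regression"]).task == ["regression"]
```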
@@ -144,7 +148,7 @@ class MMOE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
nextrec/models/multi_task/ple.py
@@ -159,8 +159,11 @@ class PLE(BaseModel):
         return "PLE"
 
     @property
-    def task_type(self):
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature],
@@ -173,7 +176,7 @@ class PLE(BaseModel):
                  num_levels: int,
                  tower_params_list: list[dict],
                  target: list[str],
-                 task: str | list[str] = 'binary',
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict | None = None,
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -185,18 +188,19 @@ class PLE(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(PLE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
@@ -247,7 +251,7 @@ class PLE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
nextrec/models/multi_task/poso.py
@@ -261,8 +261,11 @@ class POSO(BaseModel):
         return "POSO"
 
     @property
-    def task_type(self) -> list[str]:
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self) -> list[str]:
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]
 
     def __init__(
         self,
@@ -274,7 +277,7 @@ class POSO(BaseModel):
         pc_sequence_features: list[SequenceFeature] | None,
         tower_params_list: list[dict],
         target: list[str],
-        task: str | list[str] = "binary",
+        task: str | list[str] | None = None,
         architecture: str = "mlp",
         # POSO gating defaults
         gate_hidden_dim: int = 32,
@@ -307,6 +310,7 @@ class POSO(BaseModel):
         self.pc_dense_features = list(pc_dense_features or [])
         self.pc_sparse_features = list(pc_sparse_features or [])
         self.pc_sequence_features = list(pc_sequence_features or [])
+        self.num_tasks = len(target)
 
         if not self.pc_dense_features and not self.pc_sparse_features and not self.pc_sequence_features:
             raise ValueError("POSO requires at least one PC feature for personalization.")
@@ -320,13 +324,12 @@ class POSO(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs,
         )
 
@@ -387,7 +390,7 @@ class POSO(BaseModel):
         )
         self.towers = nn.ModuleList([MLP(input_dim=self.mmoe.expert_output_dim, output_layer=True, **tower_params,) for tower_params in tower_params_list])
         self.tower_heads = None
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks,)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks,)
         include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
         self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/multi_task/share_bottom.py
@@ -53,9 +53,11 @@ class ShareBottom(BaseModel):
         return "ShareBottom"
 
     @property
-    def task_type(self):
-        # Multi-task model, return list of task types
-        return self.task if isinstance(self.task, list) else [self.task]
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature],
@@ -64,7 +66,7 @@ class ShareBottom(BaseModel):
                  bottom_params: dict,
                  tower_params_list: list[dict],
                  target: list[str],
-                 task: str | list[str] = 'binary',
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -76,18 +78,19 @@ class ShareBottom(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(ShareBottom, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
@@ -120,7 +123,7 @@ class ShareBottom(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.task_type, task_dims=[1] * self.num_tasks)
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)