nextrec 0.3.6__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +32 -15
- nextrec/basic/model.py +435 -187
- nextrec/data/data_processing.py +31 -19
- nextrec/data/dataloader.py +40 -10
- nextrec/models/generative/hstu.py +3 -2
- nextrec/models/match/dssm.py +0 -1
- nextrec/models/match/dssm_v2.py +0 -1
- nextrec/models/match/mind.py +0 -1
- nextrec/models/match/sdm.py +0 -1
- nextrec/models/match/youtube_dnn.py +0 -1
- nextrec/models/multi_task/esmm.py +5 -7
- nextrec/models/multi_task/mmoe.py +10 -6
- nextrec/models/multi_task/ple.py +10 -6
- nextrec/models/multi_task/poso.py +9 -6
- nextrec/models/multi_task/share_bottom.py +10 -7
- nextrec/models/ranking/afm.py +113 -21
- nextrec/models/ranking/autoint.py +15 -9
- nextrec/models/ranking/dcn.py +8 -11
- nextrec/models/ranking/deepfm.py +5 -5
- nextrec/models/ranking/dien.py +4 -4
- nextrec/models/ranking/din.py +4 -4
- nextrec/models/ranking/fibinet.py +4 -4
- nextrec/models/ranking/fm.py +4 -4
- nextrec/models/ranking/masknet.py +4 -5
- nextrec/models/ranking/pnn.py +4 -4
- nextrec/models/ranking/widedeep.py +4 -4
- nextrec/models/ranking/xdeepfm.py +4 -4
- nextrec/utils/__init__.py +7 -3
- nextrec/utils/device.py +30 -0
- nextrec/utils/distributed.py +114 -0
- nextrec/utils/synthetic_data.py +413 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
- nextrec-0.4.1.dist-info/RECORD +66 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/data/data_processing.py
CHANGED

@@ -11,7 +11,10 @@ import pandas as pd
 from typing import Any, Mapping
 
 
-def get_column_data(data: dict | pd.DataFrame, name: str):
+def get_column_data(
+    data: dict | pd.DataFrame,
+    name: str):
+
     if isinstance(data, dict):
         return data[name] if name in data else None
     elif isinstance(data, pd.DataFrame):

@@ -24,10 +27,11 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
     raise KeyError(f"Unsupported data type for extracting column {name}")
 
 def split_dict_random(
-
-
-
-):
+    data_dict: dict,
+    test_size: float = 0.2,
+    random_state: int | None = None
+):
+
     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
         raise ValueError(f"Length mismatch: {lengths}")

@@ -51,18 +55,27 @@ def split_dict_random(
     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
     return train_dict, test_dict
 
+def split_data(
+    df: pd.DataFrame,
+    test_size: float = 0.2
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+
+    split_idx = int(len(df) * (1 - test_size))
+    train_df = df.iloc[:split_idx].reset_index(drop=True)
+    valid_df = df.iloc[split_idx:].reset_index(drop=True)
+    return train_df, valid_df
 
 def build_eval_candidates(
-
-
-
-
-
-
-
-
-
-) -> pd.DataFrame:
+    df_all: pd.DataFrame,
+    user_col: str,
+    item_col: str,
+    label_col: str,
+    user_features: pd.DataFrame,
+    item_features: pd.DataFrame,
+    num_pos_per_user: int = 5,
+    num_neg_per_pos: int = 50,
+    random_seed: int = 2025,
+) -> pd.DataFrame:
     """
     Build evaluation candidates with positive and negative samples for each user.
 

@@ -111,11 +124,10 @@ def build_eval_candidates(
     eval_df = eval_df.merge(item_features, on=item_col, how='left')
     return eval_df
 
-
 def get_user_ids(
-
-
-) -> np.ndarray | None:
+    data: Any,
+    id_columns: list[str] | str | None = None
+) -> np.ndarray | None:
     """
     Extract user IDs from various data structures.
 
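The new split_data helper is an order-preserving split rather than a shuffled one, which matters for time-ordered logs. A minimal usage sketch, assuming only the signature added above (the frame and its columns are illustrative):

import pandas as pd
from nextrec.data.data_processing import split_data

# Illustrative interaction log.
df = pd.DataFrame({
    "user_id": [1, 1, 2, 2, 3],
    "item_id": [10, 11, 10, 12, 13],
    "label": [1, 0, 1, 0, 1],
})

# Sequential 80/20 split: the first int(len(df) * 0.8) rows become the
# train frame, the rest the validation frame, both re-indexed from 0.
train_df, valid_df = split_data(df, test_size=0.2)
assert len(train_df) == 4 and len(valid_df) == 1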
nextrec/data/dataloader.py
CHANGED

@@ -15,15 +15,15 @@ import pyarrow.parquet as pq
 from pathlib import Path
 from typing import cast
 
-from
-from nextrec.data.preprocessor import DataProcessor
+from nextrec.basic.loggers import colorize
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
+from nextrec.data.preprocessor import DataProcessor
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 
-from nextrec.basic.loggers import colorize
-from nextrec.data.data_processing import get_column_data
-from nextrec.data.batch_utils import collate_fn
-from nextrec.utils.file import resolve_file_paths, read_table
 from nextrec.utils.tensor import to_tensor
+from nextrec.utils.file import resolve_file_paths, read_table
+from nextrec.data.batch_utils import collate_fn
+from nextrec.data.data_processing import get_column_data
 
 class TensorDictDataset(Dataset):
     """Dataset returning sample-level dicts matching the unified batch schema."""

@@ -118,6 +118,18 @@ class RecDataLoader(FeatureSet):
                  target: list[str] | None | str = None,
                  id_columns: str | list[str] | None = None,
                  processor: DataProcessor | None = None):
+        """
+        RecDataLoader is a unified dataloader for supporting in-memory and streaming data.
+        Basemodel will accept RecDataLoader to create dataloaders for training/evaluation/prediction.
+
+        Args:
+            dense_features: list of DenseFeature definitions
+            sparse_features: list of SparseFeature definitions
+            sequence_features: list of SequenceFeature definitions
+            target: target column name(s), e.g. 'label' or ['ctr', 'ctcvr']
+            id_columns: id column name(s) to carry through (not used for model inputs), e.g. 'user_id' or ['user_id', 'item_id']
+            processor: an instance of DataProcessor, if provided, will be used to transform data before creating tensors.
+        """
         self.processor = processor
         self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
 

@@ -127,13 +139,29 @@ class RecDataLoader(FeatureSet):
                           shuffle: bool = True,
                           load_full: bool = True,
                           chunk_size: int = 10000,
-                          num_workers: int = 0
+                          num_workers: int = 0,
+                          sampler = None) -> DataLoader:
+        """
+        Create a DataLoader from various data sources.
+
+        Args:
+            data: Data source, can be a dict, pd.DataFrame, file path (str), or existing DataLoader.
+            batch_size: Batch size for DataLoader.
+            shuffle: Whether to shuffle the data (ignored in streaming mode).
+            load_full: If True, load full data into memory; if False, use streaming mode for large files.
+            chunk_size: Chunk size for streaming mode (number of rows per chunk).
+            num_workers: Number of worker processes for data loading.
+            sampler: Optional sampler for DataLoader, only used for distributed training.
+        Returns:
+            DataLoader instance.
+        """
+
         if isinstance(data, DataLoader):
             return data
         elif isinstance(data, (str, os.PathLike)):
             return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size, num_workers=num_workers)
         elif isinstance(data, (dict, pd.DataFrame)):
-            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, sampler=sampler)
         else:
             raise ValueError(f"[RecDataLoader Error] Unsupported data type: {type(data)}")
 

@@ -141,7 +169,9 @@ class RecDataLoader(FeatureSet):
                            data: dict | pd.DataFrame,
                            batch_size: int,
                            shuffle: bool,
-                           num_workers: int = 0
+                           num_workers: int = 0,
+                           sampler=None) -> DataLoader:
+
         raw_data = data
 
         if self.processor is not None:

@@ -152,7 +182,7 @@ class RecDataLoader(FeatureSet):
         if tensors is None:
             raise ValueError("[RecDataLoader Error] No valid tensors could be built from the provided data.")
         dataset = TensorDictDataset(tensors)
-        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers)
+        return DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
 
     def create_from_path(self,
                          path: str,
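The sampler threading above is what connects RecDataLoader to the distributed utilities introduced in this release (nextrec/utils/distributed.py). A hedged sketch of the pattern in plain PyTorch, mirroring the shuffle=False-when-sampler guard from the diff; the TensorDataset here is a stand-in for the TensorDictDataset built inside create_from_memory:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100).unsqueeze(1))

# Passing num_replicas/rank explicitly keeps the sketch runnable without
# an initialized process group; under real DDP they come from the group.
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)

# The sampler owns shuffling, hence the
# `shuffle=False if sampler is not None else shuffle` guard in the diff.
loader = DataLoader(dataset, batch_size=8,
                    shuffle=False if sampler is not None else True,
                    sampler=sampler)
assert len(sampler) == 25  # each of the 4 ranks sees a quarter of the data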
nextrec/models/generative/hstu.py
CHANGED

@@ -255,7 +255,7 @@ class HSTU(BaseModel):
         return "HSTU"
 
     @property
-    def
+    def default_task(self) -> str:
         return "multiclass"
 
     def __init__(

@@ -275,6 +275,7 @@ class HSTU(BaseModel):
 
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
+        task: str | list[str] | None = None,
         optimizer: str = "adam",
         optimizer_params: Optional[dict] = None,
         scheduler: Optional[str] = None,

@@ -307,7 +308,7 @@ class HSTU(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
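The `task=task or self.default_task` fallback introduced here recurs in every model file below, so its semantics are worth pinning down once. A standalone sketch (Demo is illustrative, not a nextrec class):

class Demo:
    @property
    def default_task(self) -> str:
        return "multiclass"

    def __init__(self, task: str | list[str] | None = None):
        # An explicit task wins; otherwise fall back to the default.
        # Note `or` treats any falsy task (None, "", []) as unset.
        self.task = task or self.default_task

assert Demo().task == "multiclass"
assert Demo(task="binary").task == "binary"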
nextrec/models/match/dssm.py
CHANGED
nextrec/models/match/dssm_v2.py
CHANGED
nextrec/models/match/mind.py
CHANGED
nextrec/models/match/sdm.py
CHANGED
nextrec/models/match/youtube_dnn.py
CHANGED
nextrec/models/multi_task/esmm.py
CHANGED

@@ -64,10 +64,9 @@ class ESMM(BaseModel):
     @property
     def model_name(self):
         return "ESMM"
-
+
     @property
-    def
-    # ESMM has fixed task types: CTR (binary) and CVR (binary)
+    def default_task(self):
         return ['binary', 'binary']
 
     def __init__(self,

@@ -77,7 +76,7 @@ class ESMM(BaseModel):
                  ctr_params: dict,
                  cvr_params: dict,
                  target: list[str] = ['ctr', 'ctcvr'], # Note: ctcvr = ctr * cvr
-                 task: list[str] =
+                 task: list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",

@@ -98,13 +97,12 @@ class ESMM(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task, # Both CTR and CTCVR are binary classification
+            task=task or self.default_task, # Both CTR and CTCVR are binary classification
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 

@@ -126,7 +124,7 @@ class ESMM(BaseModel):
 
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1, 1])
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
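The inline comment on target compresses the ESMM idea into half a line; spelled out with illustrative numbers, the decomposition it refers to is:

# pCTCVR = pCTR * pCVR, which is why the supervised heads are 'ctr' and
# 'ctcvr' and default_task is pinned to ['binary', 'binary'].
p_ctr, p_cvr = 0.10, 0.30
p_ctcvr = p_ctr * p_cvr
assert abs(p_ctcvr - 0.03) < 1e-12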
nextrec/models/multi_task/mmoe.py
CHANGED

@@ -65,8 +65,11 @@ class MMOE(BaseModel):
         return "MMOE"
 
     @property
-    def
-
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature]=[],

@@ -76,7 +79,7 @@ class MMOE(BaseModel):
                  num_experts: int=3,
                  tower_params_list: list[dict]=[],
                  target: list[str]=[],
-                 task: str | list[str] =
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",

@@ -88,18 +91,19 @@ class MMOE(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(MMOE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 

@@ -144,7 +148,7 @@ class MMOE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
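MMOE, PLE, POSO, and ShareBottom all gain this same default_task logic. The getattr guard exists because the property can be read before num_tasks is set on the instance; assigning self.num_tasks = len(target) ahead of super().__init__ makes the fallback resolve to one 'binary' head per target. A standalone sketch of the interplay (Demo is illustrative, not a nextrec class):

class Demo:
    @property
    def default_task(self) -> list[str]:
        # Degrade to a single binary task if num_tasks is missing or 0
        # at the moment the property is evaluated.
        num_tasks = getattr(self, "num_tasks", None)
        if num_tasks is not None and num_tasks > 0:
            return ["binary"] * num_tasks
        return ["binary"]

    def __init__(self, target: list[str], task=None):
        self.num_tasks = len(target)  # must precede the fallback below
        self.task = task or self.default_task

assert Demo(target=["ctr", "cvr", "like"]).task == ["binary", "binary", "binary"]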
nextrec/models/multi_task/ple.py
CHANGED

@@ -159,8 +159,11 @@ class PLE(BaseModel):
         return "PLE"
 
     @property
-    def
-
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature],

@@ -173,7 +176,7 @@ class PLE(BaseModel):
                  num_levels: int,
                  tower_params_list: list[dict],
                  target: list[str],
-                 task: str | list[str] =
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict | None = None,
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",

@@ -185,18 +188,19 @@ class PLE(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(PLE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 

@@ -247,7 +251,7 @@ class PLE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
nextrec/models/multi_task/poso.py
CHANGED

@@ -261,8 +261,11 @@ class POSO(BaseModel):
         return "POSO"
 
     @property
-    def
-
+    def default_task(self) -> list[str]:
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]
 
     def __init__(
         self,

@@ -274,7 +277,7 @@ class POSO(BaseModel):
         pc_sequence_features: list[SequenceFeature] | None,
         tower_params_list: list[dict],
         target: list[str],
-        task: str | list[str] =
+        task: str | list[str] | None = None,
         architecture: str = "mlp",
         # POSO gating defaults
         gate_hidden_dim: int = 32,

@@ -307,6 +310,7 @@ class POSO(BaseModel):
         self.pc_dense_features = list(pc_dense_features or [])
         self.pc_sparse_features = list(pc_sparse_features or [])
         self.pc_sequence_features = list(pc_sequence_features or [])
+        self.num_tasks = len(target)
 
         if not self.pc_dense_features and not self.pc_sparse_features and not self.pc_sequence_features:
             raise ValueError("POSO requires at least one PC feature for personalization.")

@@ -320,13 +324,12 @@ class POSO(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs,
         )
 

@@ -387,7 +390,7 @@ class POSO(BaseModel):
         )
         self.towers = nn.ModuleList([MLP(input_dim=self.mmoe.expert_output_dim, output_layer=True, **tower_params,) for tower_params in tower_params_list])
         self.tower_heads = None
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks,)
         include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
         self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/multi_task/share_bottom.py
CHANGED

@@ -53,9 +53,11 @@ class ShareBottom(BaseModel):
         return "ShareBottom"
 
     @property
-    def
-
-
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                  dense_features: list[DenseFeature],

@@ -64,7 +66,7 @@ class ShareBottom(BaseModel):
                  bottom_params: dict,
                  tower_params_list: list[dict],
                  target: list[str],
-                 task: str | list[str] =
+                 task: str | list[str] | None = None,
                  optimizer: str = "adam",
                  optimizer_params: dict = {},
                  loss: str | nn.Module | list[str | nn.Module] | None = "bce",

@@ -76,18 +78,19 @@ class ShareBottom(BaseModel):
                  dense_l2_reg=1e-4,
                  **kwargs):
 
+        self.num_tasks = len(target)
+
         super(ShareBottom, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 

@@ -120,7 +123,7 @@ class ShareBottom(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
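One change repeats across ESMM, MMOE, PLE, POSO, and ShareBottom: the hardcoded early_stop_patience=20 disappears from each super().__init__ call, leaving **kwargs as the route for a caller-chosen value. Why a pinned keyword plus forwarded kwargs is a problem, sketched with hypothetical classes (Base, OldStyle, and NewStyle are illustrative, not nextrec's BaseModel):

class Base:
    def __init__(self, early_stop_patience: int = 5, **_):
        self.early_stop_patience = early_stop_patience

class OldStyle(Base):
    def __init__(self, **kwargs):
        # Pre-0.4.1 shape: the pinned value collides with any
        # early_stop_patience the caller forwarded through kwargs.
        super().__init__(early_stop_patience=20, **kwargs)

class NewStyle(Base):
    def __init__(self, **kwargs):
        # 0.4.1 shape: kwargs pass through untouched.
        super().__init__(**kwargs)

assert NewStyle(early_stop_patience=10).early_stop_patience == 10
try:
    OldStyle(early_stop_patience=10)
except TypeError:
    pass  # "got multiple values for keyword argument 'early_stop_patience'"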