nextrec 0.3.5-py3-none-any.whl → 0.4.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +0 -30
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +32 -15
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/model.py +440 -189
- nextrec/basic/session.py +4 -2
- nextrec/data/__init__.py +0 -25
- nextrec/data/data_processing.py +31 -19
- nextrec/data/dataloader.py +51 -16
- nextrec/models/generative/__init__.py +0 -5
- nextrec/models/generative/hstu.py +3 -2
- nextrec/models/match/__init__.py +0 -13
- nextrec/models/match/dssm.py +0 -1
- nextrec/models/match/dssm_v2.py +0 -1
- nextrec/models/match/mind.py +0 -1
- nextrec/models/match/sdm.py +0 -1
- nextrec/models/match/youtube_dnn.py +0 -1
- nextrec/models/multi_task/__init__.py +0 -0
- nextrec/models/multi_task/esmm.py +5 -7
- nextrec/models/multi_task/mmoe.py +10 -6
- nextrec/models/multi_task/ple.py +10 -6
- nextrec/models/multi_task/poso.py +9 -6
- nextrec/models/multi_task/share_bottom.py +10 -7
- nextrec/models/ranking/__init__.py +0 -27
- nextrec/models/ranking/afm.py +113 -21
- nextrec/models/ranking/autoint.py +15 -9
- nextrec/models/ranking/dcn.py +8 -11
- nextrec/models/ranking/deepfm.py +5 -5
- nextrec/models/ranking/dien.py +4 -4
- nextrec/models/ranking/din.py +4 -4
- nextrec/models/ranking/fibinet.py +4 -4
- nextrec/models/ranking/fm.py +4 -4
- nextrec/models/ranking/masknet.py +4 -5
- nextrec/models/ranking/pnn.py +4 -4
- nextrec/models/ranking/widedeep.py +4 -4
- nextrec/models/ranking/xdeepfm.py +4 -4
- nextrec/utils/__init__.py +7 -3
- nextrec/utils/device.py +32 -1
- nextrec/utils/distributed.py +114 -0
- nextrec/utils/synthetic_data.py +413 -0
- {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
- nextrec-0.4.1.dist-info/RECORD +66 -0
- nextrec-0.3.5.dist-info/RECORD +0 -63
- {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
- {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py
CHANGED

@@ -22,6 +22,7 @@ class Session:
 
     experiment_id: str
     root: Path
+    log_basename: str  # The base name for log files, without path separators
 
     @property
     def logs_dir(self) -> Path:
@@ -60,7 +61,6 @@ class Session:
         return path
 
 def create_session(experiment_id: str | Path | None = None) -> Session:
-    """Create a :class:`Session` instance with prepared directories."""
 
     if experiment_id is not None and str(experiment_id).strip():
         exp_id = str(experiment_id).strip()
@@ -68,6 +68,8 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
         # Use local time for session naming
         exp_id = "nextrec_session_" + datetime.now().strftime("%Y%m%d")
 
+    log_basename = Path(exp_id).name if exp_id else exp_id
+
     if (
         os.getenv("PYTEST_CURRENT_TEST")
         or os.getenv("PYTEST_RUNNING")
@@ -82,7 +84,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
     session_path.mkdir(parents=True, exist_ok=True)
     root = session_path.resolve()
 
-    return Session(experiment_id=exp_id, root=root)
+    return Session(experiment_id=exp_id, root=root, log_basename=log_basename)
 
 def resolve_save_path(
     path: str | os.PathLike | Path | None,
nextrec/data/__init__.py
CHANGED

@@ -1,22 +1,4 @@
-"""
-Data utilities package for NextRec
-
-This package provides data processing and manipulation utilities organized by category:
-- batch_utils: Batch collation and processing
-- data_processing: Data manipulation and user ID extraction
-- data_utils: Legacy module (re-exports from specialized modules)
-- dataloader: Dataset and DataLoader implementations
-- preprocessor: Data preprocessing pipeline
-
-Date: create on 13/11/2025
-Last update: 03/12/2025 (refactored)
-Author: Yang Zhou, zyaztec@gmail.com
-"""
-
-# Batch utilities
 from nextrec.data.batch_utils import collate_fn, batch_to_dict, stack_section
-
-# Data processing utilities
 from nextrec.data.data_processing import (
     get_column_data,
     split_dict_random,
@@ -24,7 +6,6 @@ from nextrec.data.data_processing import (
     get_user_ids,
 )
 
-# File utilities (from utils package)
 from nextrec.utils.file import (
     resolve_file_paths,
     iter_file_chunks,
@@ -33,7 +14,6 @@ from nextrec.utils.file import (
     default_output_dir,
 )
 
-# DataLoader components
 from nextrec.data.dataloader import (
     TensorDictDataset,
     FileDataset,
@@ -41,13 +21,8 @@ from nextrec.data.dataloader import (
     build_tensors_from_data,
 )
 
-# Preprocessor
 from nextrec.data.preprocessor import DataProcessor
-
-# Feature definitions
 from nextrec.basic.features import FeatureSet
-
-# Legacy module (for backward compatibility)
 from nextrec.data import data_utils
 
 __all__ = [
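The cleanup removes only the module docstring and the section comments; the re-exported names themselves are unchanged. A quick check of the surviving public surface (assuming nextrec 0.4.1 is installed):

# All of these imports appear verbatim in the __init__ diff above.
from nextrec.data import (
    collate_fn,
    get_column_data,
    split_dict_random,
    DataProcessor,
    FeatureSet,
    data_utils,
)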
nextrec/data/data_processing.py
CHANGED

@@ -11,7 +11,10 @@ import pandas as pd
 from typing import Any, Mapping
 
 
-def get_column_data(data: dict | pd.DataFrame, name: str):
+def get_column_data(
+    data: dict | pd.DataFrame,
+    name: str):
+
     if isinstance(data, dict):
         return data[name] if name in data else None
     elif isinstance(data, pd.DataFrame):
@@ -24,10 +27,11 @@ def get_column_data(data: dict | pd.DataFrame, name: str)
     raise KeyError(f"Unsupported data type for extracting column {name}")
 
 def split_dict_random(
-        data_dict: dict,
-        test_size: float = 0.2,
-        random_state: int | None = None
-):
+    data_dict: dict,
+    test_size: float = 0.2,
+    random_state: int | None = None
+):
+
     lengths = [len(v) for v in data_dict.values()]
     if len(set(lengths)) != 1:
         raise ValueError(f"Length mismatch: {lengths}")
@@ -51,18 +55,27 @@ def split_dict_random(
     test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
     return train_dict, test_dict
 
+def split_data(
+    df: pd.DataFrame,
+    test_size: float = 0.2
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+
+    split_idx = int(len(df) * (1 - test_size))
+    train_df = df.iloc[:split_idx].reset_index(drop=True)
+    valid_df = df.iloc[split_idx:].reset_index(drop=True)
+    return train_df, valid_df
 
 def build_eval_candidates(
-        df_all: pd.DataFrame,
-        user_col: str,
-        item_col: str,
-        label_col: str,
-        user_features: pd.DataFrame,
-        item_features: pd.DataFrame,
-        num_pos_per_user: int = 5,
-        num_neg_per_pos: int = 50,
-        random_seed: int = 2025,
-) -> pd.DataFrame:
+    df_all: pd.DataFrame,
+    user_col: str,
+    item_col: str,
+    label_col: str,
+    user_features: pd.DataFrame,
+    item_features: pd.DataFrame,
+    num_pos_per_user: int = 5,
+    num_neg_per_pos: int = 50,
+    random_seed: int = 2025,
+) -> pd.DataFrame:
     """
     Build evaluation candidates with positive and negative samples for each user.
 
@@ -111,11 +124,10 @@ def build_eval_candidates(
     eval_df = eval_df.merge(item_features, on=item_col, how='left')
     return eval_df
 
-
 def get_user_ids(
-        data: Any,
-        id_columns: list[str] | str | None = None
-) -> np.ndarray | None:
+    data: Any,
+    id_columns: list[str] | str | None = None
+) -> np.ndarray | None:
     """
    Extract user IDs from various data structures.
 
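Unlike split_dict_random, the new split_data helper splits a DataFrame by position, so row order (for example a time ordering) is preserved between train and validation. A small usage sketch with made-up data:

import pandas as pd

from nextrec.data.data_processing import split_data

df = pd.DataFrame({"user_id": range(10), "label": [0, 1] * 5})

# Positional 80/20 split: the first 8 rows become train, the last 2
# become validation, both returned with a fresh RangeIndex.
train_df, valid_df = split_data(df, test_size=0.2)
print(len(train_df), len(valid_df))  # 8 2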
nextrec/data/dataloader.py
CHANGED

@@ -15,15 +15,15 @@ import pyarrow.parquet as pq
 from pathlib import Path
 from typing import cast
 
-from torch.utils.data import DataLoader, Dataset, IterableDataset
-from nextrec.data.preprocessor import DataProcessor
+from nextrec.basic.loggers import colorize
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
+from nextrec.data.preprocessor import DataProcessor
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 
-from nextrec.basic.loggers import colorize
-from nextrec.data.data_processing import get_column_data
-from nextrec.data.batch_utils import collate_fn
-from nextrec.utils.file import resolve_file_paths, read_table
 from nextrec.utils.tensor import to_tensor
+from nextrec.utils.file import resolve_file_paths, read_table
+from nextrec.data.batch_utils import collate_fn
+from nextrec.data.data_processing import get_column_data
 
 class TensorDictDataset(Dataset):
     """Dataset returning sample-level dicts matching the unified batch schema."""
@@ -118,6 +118,18 @@ class RecDataLoader(FeatureSet):
                  target: list[str] | None | str = None,
                  id_columns: str | list[str] | None = None,
                  processor: DataProcessor | None = None):
+        """
+        RecDataLoader is a unified dataloader for supporting in-memory and streaming data.
+        Basemodel will accept RecDataLoader to create dataloaders for training/evaluation/prediction.
+
+        Args:
+            dense_features: list of DenseFeature definitions
+            sparse_features: list of SparseFeature definitions
+            sequence_features: list of SequenceFeature definitions
+            target: target column name(s), e.g. 'label' or ['ctr', 'ctcvr']
+            id_columns: id column name(s) to carry through (not used for model inputs), e.g. 'user_id' or ['user_id', 'item_id']
+            processor: an instance of DataProcessor, if provided, will be used to transform data before creating tensors.
+        """
         self.processor = processor
         self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
 
@@ -126,20 +138,40 @@ class RecDataLoader(FeatureSet):
                batch_size: int = 32,
                shuffle: bool = True,
                load_full: bool = True,
-               chunk_size: int = 10000
+               chunk_size: int = 10000,
+               num_workers: int = 0,
+               sampler = None) -> DataLoader:
+        """
+        Create a DataLoader from various data sources.
+
+        Args:
+            data: Data source, can be a dict, pd.DataFrame, file path (str), or existing DataLoader.
+            batch_size: Batch size for DataLoader.
+            shuffle: Whether to shuffle the data (ignored in streaming mode).
+            load_full: If True, load full data into memory; if False, use streaming mode for large files.
+            chunk_size: Chunk size for streaming mode (number of rows per chunk).
+            num_workers: Number of worker processes for data loading.
+            sampler: Optional sampler for DataLoader, only used for distributed training.
+        Returns:
+            DataLoader instance.
+        """
+
         if isinstance(data, DataLoader):
             return data
         elif isinstance(data, (str, os.PathLike)):
-            return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size)
+            return self.create_from_path(path=data, batch_size=batch_size, shuffle=shuffle, load_full=load_full, chunk_size=chunk_size, num_workers=num_workers)
         elif isinstance(data, (dict, pd.DataFrame)):
-            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle)
+            return self.create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, sampler=sampler)
         else:
             raise ValueError(f"[RecDataLoader Error] Unsupported data type: {type(data)}")
 
     def create_from_memory(self,
                            data: dict | pd.DataFrame,
                            batch_size: int,
-                           shuffle: bool
+                           shuffle: bool,
+                           num_workers: int = 0,
+                           sampler=None) -> DataLoader:
+
         raw_data = data
 
         if self.processor is not None:
@@ -150,14 +182,15 @@ class RecDataLoader(FeatureSet):
         if tensors is None:
             raise ValueError("[RecDataLoader Error] No valid tensors could be built from the provided data.")
         dataset = TensorDictDataset(tensors)
-        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
+        return DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
 
     def create_from_path(self,
                          path: str,
                          batch_size: int,
                          shuffle: bool,
                          load_full: bool,
-                         chunk_size: int = 10000
+                         chunk_size: int = 10000,
+                         num_workers: int = 0) -> DataLoader:
         file_paths, file_type = resolve_file_paths(str(Path(path)))
         # Load full data into memory
         if load_full:
@@ -169,6 +202,7 @@ class RecDataLoader(FeatureSet):
             except OSError:
                 pass
             try:
+                df = read_table(file_path, file_type=file_type)
                 dfs.append(df)
             except MemoryError as exc:
                 raise MemoryError(f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using load_full=False with streaming.") from exc
@@ -176,22 +210,23 @@ class RecDataLoader(FeatureSet):
                 combined_df = pd.concat(dfs, ignore_index=True)
             except MemoryError as exc:
                 raise MemoryError(f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use load_full=False to stream or reduce chunk_size.") from exc
-            return self.create_from_memory(combined_df, batch_size, shuffle,)
+            return self.create_from_memory(combined_df, batch_size, shuffle, num_workers=num_workers)
         else:
-            return self.load_files_streaming(file_paths, file_type, batch_size, chunk_size, shuffle)
+            return self.load_files_streaming(file_paths, file_type, batch_size, chunk_size, shuffle, num_workers=num_workers)
 
     def load_files_streaming(self,
                              file_paths: list[str],
                              file_type: str,
                              batch_size: int,
                              chunk_size: int,
-                             shuffle: bool
+                             shuffle: bool,
+                             num_workers: int = 0) -> DataLoader:
         if shuffle:
             logging.info("[RecDataLoader Info] Shuffle is ignored in streaming mode (IterableDataset).")
         if batch_size != 1:
             logging.info("[RecDataLoader Info] Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
         dataset = FileDataset(file_paths=file_paths, dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target_columns=self.target_columns, id_columns=self.id_columns, chunk_size=chunk_size, file_type=file_type, processor=self.processor)
-        return DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
+        return DataLoader(dataset, batch_size=1, collate_fn=collate_fn, num_workers=num_workers)
 
 def normalize_sequence_column(column, feature: SequenceFeature) -> np.ndarray:
     if isinstance(column, pd.Series):
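The num_workers and sampler arguments threaded through create, create_from_memory, create_from_path, and load_files_streaming are what hook RecDataLoader into multi-process loading and the new distributed utilities (nextrec/utils/distributed.py, added in this release). A hedged sketch of the call shape (the feature constructors' arguments are assumptions, not taken from this diff; only RecDataLoader and the create signature are):

import pandas as pd

from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.data.dataloader import RecDataLoader

# Hypothetical feature definitions; the real constructor arguments
# depend on nextrec.basic.features, which this diff does not show.
dense = [DenseFeature(name="price")]
sparse = [SparseFeature(name="user_id", vocab_size=1000, embedding_dim=8)]

df = pd.DataFrame({"price": [1.0, 2.0], "user_id": [3, 7], "label": [0, 1]})

loader = RecDataLoader(dense_features=dense, sparse_features=sparse, target="label")

# New in 0.4.1: num_workers is forwarded to torch's DataLoader; sampler is
# reserved for a DistributedSampler supplied by the trainer, in which case
# shuffle is forced to False (see create_from_memory above).
train_loader = loader.create(df, batch_size=2, shuffle=True, num_workers=2)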
nextrec/models/generative/hstu.py
CHANGED

@@ -255,7 +255,7 @@ class HSTU(BaseModel):
         return "HSTU"
 
     @property
-    def
+    def default_task(self) -> str:
         return "multiclass"
 
     def __init__(
@@ -275,6 +275,7 @@ class HSTU(BaseModel):
 
         tie_embeddings: bool = True,
         target: Optional[list[str] | str] = None,
+        task: str | list[str] | None = None,
        optimizer: str = "adam",
         optimizer_params: Optional[dict] = None,
         scheduler: Optional[str] = None,
@@ -307,7 +308,7 @@ class HSTU(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=self.
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
nextrec/models/match/__init__.py
CHANGED
nextrec/models/match/dssm.py
CHANGED
nextrec/models/match/dssm_v2.py
CHANGED
nextrec/models/match/mind.py
CHANGED
nextrec/models/match/sdm.py
CHANGED
nextrec/models/match/youtube_dnn.py
CHANGED

nextrec/models/multi_task/__init__.py
File without changes
nextrec/models/multi_task/esmm.py
CHANGED

@@ -64,10 +64,9 @@ class ESMM(BaseModel):
     @property
     def model_name(self):
         return "ESMM"
-
+
     @property
-    def
-    # ESMM has fixed task types: CTR (binary) and CVR (binary)
+    def default_task(self):
         return ['binary', 'binary']
 
     def __init__(self,
@@ -77,7 +76,7 @@ class ESMM(BaseModel):
                 ctr_params: dict,
                 cvr_params: dict,
                 target: list[str] = ['ctr', 'ctcvr'],  # Note: ctcvr = ctr * cvr
-                task: list[str] =
+                task: list[str] | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict = {},
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -98,13 +97,12 @@ class ESMM(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,  # Both CTR and CTCVR are binary classification
+            task=task or self.default_task,  # Both CTR and CTCVR are binary classification
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
@@ -126,7 +124,7 @@ class ESMM(BaseModel):
 
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1, 1])
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['ctr_tower', 'cvr_tower'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/multi_task/mmoe.py
CHANGED

@@ -65,8 +65,11 @@ class MMOE(BaseModel):
         return "MMOE"
 
     @property
-    def
-
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                 dense_features: list[DenseFeature]=[],
@@ -76,7 +79,7 @@ class MMOE(BaseModel):
                 num_experts: int=3,
                 tower_params_list: list[dict]=[],
                 target: list[str]=[],
-                task: str | list[str] =
+                task: str | list[str] | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict = {},
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -88,18 +91,19 @@ class MMOE(BaseModel):
                 dense_l2_reg=1e-4,
                 **kwargs):
 
+        self.num_tasks = len(target)
+
         super(MMOE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
@@ -144,7 +148,7 @@ class MMOE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['experts', 'gates', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params,)
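MMOE's default_task (and likewise PLE, POSO, and ShareBottom below) now sizes itself from self.num_tasks, which is set from len(target) before the base constructor runs; the getattr guard keeps the property safe if it is read before num_tasks exists. A stand-in sketch of that interplay:

class MultiTaskStub:
    def __init__(self, target: list[str], task: list[str] | None = None):
        self.num_tasks = len(target)  # set before the base __init__, as in the diff
        self.task = task or self.default_task

    @property
    def default_task(self) -> list[str]:
        # One binary head per target; fall back to a single head.
        num_tasks = getattr(self, "num_tasks", None)
        if num_tasks is not None and num_tasks > 0:
            return ["binary"] * num_tasks
        return ["binary"]

print(MultiTaskStub(target=["ctr", "ctcvr"]).task)  # ['binary', 'binary']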
nextrec/models/multi_task/ple.py
CHANGED

@@ -159,8 +159,11 @@ class PLE(BaseModel):
         return "PLE"
 
     @property
-    def
-
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                 dense_features: list[DenseFeature],
@@ -173,7 +176,7 @@ class PLE(BaseModel):
                 num_levels: int,
                 tower_params_list: list[dict],
                 target: list[str],
-                task: str | list[str] =
+                task: str | list[str] | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict | None = None,
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -185,18 +188,19 @@ class PLE(BaseModel):
                 dense_l2_reg=1e-4,
                 **kwargs):
 
+        self.num_tasks = len(target)
+
         super(PLE, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
@@ -247,7 +251,7 @@ class PLE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['cgc_layers', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=self.loss, loss_params=loss_params)
nextrec/models/multi_task/poso.py
CHANGED

@@ -261,8 +261,11 @@ class POSO(BaseModel):
         return "POSO"
 
     @property
-    def
-
+    def default_task(self) -> list[str]:
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ["binary"] * num_tasks
+        return ["binary"]
 
     def __init__(
         self,
@@ -274,7 +277,7 @@ class POSO(BaseModel):
         pc_sequence_features: list[SequenceFeature] | None,
         tower_params_list: list[dict],
         target: list[str],
-        task: str | list[str] =
+        task: str | list[str] | None = None,
         architecture: str = "mlp",
         # POSO gating defaults
         gate_hidden_dim: int = 32,
@@ -307,6 +310,7 @@ class POSO(BaseModel):
         self.pc_dense_features = list(pc_dense_features or [])
         self.pc_sparse_features = list(pc_sparse_features or [])
         self.pc_sequence_features = list(pc_sequence_features or [])
+        self.num_tasks = len(target)
 
         if not self.pc_dense_features and not self.pc_sparse_features and not self.pc_sequence_features:
             raise ValueError("POSO requires at least one PC feature for personalization.")
@@ -320,13 +324,12 @@ class POSO(BaseModel):
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs,
         )
 
@@ -387,7 +390,7 @@ class POSO(BaseModel):
         )
         self.towers = nn.ModuleList([MLP(input_dim=self.mmoe.expert_output_dim, output_layer=True, **tower_params,) for tower_params in tower_params_list])
         self.tower_heads = None
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks,)
         include_modules = ["towers", "tower_heads"] if self.architecture == "mlp" else ["mmoe", "towers"]
         self.register_regularization_weights(embedding_attr="embedding", include_modules=include_modules)
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
nextrec/models/multi_task/share_bottom.py
CHANGED

@@ -53,9 +53,11 @@ class ShareBottom(BaseModel):
         return "ShareBottom"
 
     @property
-    def
-
-
+    def default_task(self):
+        num_tasks = getattr(self, "num_tasks", None)
+        if num_tasks is not None and num_tasks > 0:
+            return ['binary'] * num_tasks
+        return ['binary']
 
     def __init__(self,
                 dense_features: list[DenseFeature],
@@ -64,7 +66,7 @@ class ShareBottom(BaseModel):
                 bottom_params: dict,
                 tower_params_list: list[dict],
                 target: list[str],
-                task: str | list[str] =
+                task: str | list[str] | None = None,
                 optimizer: str = "adam",
                 optimizer_params: dict = {},
                 loss: str | nn.Module | list[str | nn.Module] | None = "bce",
@@ -76,18 +78,19 @@ class ShareBottom(BaseModel):
                 dense_l2_reg=1e-4,
                 **kwargs):
 
+        self.num_tasks = len(target)
+
         super(ShareBottom, self).__init__(
             dense_features=dense_features,
             sparse_features=sparse_features,
             sequence_features=sequence_features,
             target=target,
-            task=task,
+            task=task or self.default_task,
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=20,
             **kwargs
         )
 
@@ -120,7 +123,7 @@ class ShareBottom(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(task_type=self.
+        self.prediction_layer = PredictionLayer(task_type=self.default_task, task_dims=[1] * self.num_tasks)
         # Register regularization weights
         self.register_regularization_weights(embedding_attr='embedding', include_modules=['bottom', 'towers'])
         self.compile(optimizer=optimizer, optimizer_params=optimizer_params, loss=loss, loss_params=loss_params)
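One behavioral change is easy to miss: ESMM, MMOE, PLE, POSO, and ShareBottom all stop hard-coding early_stop_patience=20 in their super().__init__ calls, so early-stopping patience now follows the BaseModel default unless callers pass it themselves. A hedged sketch of the opt-in call shape (that early_stop_patience is accepted through **kwargs is inferred from the removed lines; the feature lists and MLP params are placeholders and may need real values to construct a working model):

from nextrec.models.multi_task.share_bottom import ShareBottom

model = ShareBottom(
    dense_features=[],                   # placeholder feature definitions
    sparse_features=[],
    sequence_features=[],
    bottom_params={"dims": [64]},        # hypothetical MLP params
    tower_params_list=[{"dims": [32]}],  # one tower per target
    target=["ctr"],
    early_stop_patience=20,              # previously forced; now opt-in via **kwargs
)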