nextrec-0.4.17-py3-none-any.whl → nextrec-0.4.19-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- nextrec/__version__.py +1 -1
- nextrec/basic/heads.py +1 -3
- nextrec/basic/loggers.py +5 -5
- nextrec/basic/model.py +210 -82
- nextrec/cli.py +5 -5
- nextrec/data/dataloader.py +93 -95
- nextrec/data/preprocessor.py +108 -46
- nextrec/loss/grad_norm.py +13 -13
- nextrec/models/multi_task/esmm.py +9 -11
- nextrec/models/multi_task/mmoe.py +18 -18
- nextrec/models/multi_task/ple.py +33 -33
- nextrec/models/multi_task/poso.py +21 -20
- nextrec/models/multi_task/share_bottom.py +16 -16
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +2 -2
- nextrec/models/ranking/eulernet.py +2 -2
- nextrec/models/ranking/ffm.py +2 -2
- nextrec/models/ranking/fm.py +2 -2
- nextrec/models/ranking/lr.py +2 -2
- nextrec/models/ranking/masknet.py +2 -4
- nextrec/models/ranking/pnn.py +3 -3
- nextrec/models/ranking/widedeep.py +6 -7
- nextrec/models/ranking/xdeepfm.py +3 -3
- nextrec/utils/console.py +1 -1
- nextrec/utils/data.py +154 -32
- nextrec/utils/model.py +86 -1
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/METADATA +8 -7
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/RECORD +34 -34
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/WHEEL +0 -0
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/licenses/LICENSE +0 -0
nextrec/models/ranking/deepfm.py
CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 Reference:
 [1] Guo H, Tang R, Ye Y, et al. DeepFM: A factorization-machine based neural network
@@ -134,4 +134,4 @@ class DeepFM(BaseModel):
         y_deep = self.mlp(input_deep)  # [B, 1]
 
         y = y_linear + y_fm + y_deep
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
nextrec/models/ranking/eulernet.py
CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Zhao Z, Zhang H, Tang H, et al. EulerNet: Efficient and Effective Feature
@@ -332,4 +332,4 @@ class EulerNet(BaseModel):
         r, p = layer(r, p)
         r_flat = r.reshape(r.size(0), self.num_orders * self.embedding_dim)
         p_flat = p.reshape(p.size(0), self.num_orders * self.embedding_dim)
-        return self.w(r_flat) + self.w_im(p_flat)
+        return self.w(r_flat) + self.w_im(p_flat)
nextrec/models/ranking/ffm.py
CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 19/12/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Juan Y, Zhuang Y, Chin W-S, et al. Field-aware Factorization Machines for CTR
@@ -273,4 +273,4 @@ class FFM(BaseModel):
         )
 
         y = y_linear + y_interaction
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
nextrec/models/ranking/fm.py
CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Rendle S. Factorization machines[C]//ICDM. 2010: 995-1000.
@@ -125,4 +125,4 @@ class FM(BaseModel):
         y_linear = self.linear(input_fm.flatten(start_dim=1))
         y_fm = self.fm(input_fm)
         y = y_linear + y_fm
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
nextrec/models/ranking/lr.py
CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Hosmer D W, Lemeshow S, Sturdivant R X. Applied Logistic Regression.
@@ -116,4 +116,4 @@ class LR(BaseModel):
     def forward(self, x):
         input_linear = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
         y = self.linear(input_linear)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
nextrec/models/ranking/masknet.py
CHANGED
@@ -1,6 +1,6 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
+Checkpoint: edit on 23/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Wang Z, She Q, Zhang J. MaskNet: Introducing Feature-Wise
@@ -290,7 +290,6 @@ class MaskNet(BaseModel):
                 embedding_attr="embedding",
                 include_modules=["mask_blocks", "output_layer"],
             )
-            # serial
         else:
             self.register_regularization_weights(
                 embedding_attr="embedding", include_modules=["mask_blocks", "final_mlp"]
@@ -315,7 +314,6 @@ class MaskNet(BaseModel):
                 block_outputs.append(h)
             concat_hidden = torch.cat(block_outputs, dim=-1)
             logit = self.final_mlp(concat_hidden)  # [B, 1]
-            # serial
         else:
             hidden = self.first_block(field_emb, v_emb_flat)
             hidden = self.block_dropout(hidden)
@@ -324,4 +322,4 @@ class MaskNet(BaseModel):
             hidden = self.block_dropout(hidden)
             logit = self.output_layer(hidden)  # [B, 1]
         y = self.prediction_layer(logit)
-        return y
+        return y
nextrec/models/ranking/pnn.py
CHANGED
@@ -1,7 +1,7 @@
 """
 Date: create on 09/11/2025
-
-
+Checkpoint: edit on 23/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response
     prediction[C]//ICDM. 2016: 1149-1154. (https://arxiv.org/abs/1611.00144)
@@ -198,4 +198,4 @@ class PNN(BaseModel):
 
         deep_input = torch.cat([linear_signal, product_signal], dim=1)
         y = self.mlp(deep_input)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
nextrec/models/ranking/widedeep.py
CHANGED
@@ -1,12 +1,11 @@
 """
 Date: create on 09/11/2025
-Checkpoint: edit on
-Author:
-Yang Zhou,zyaztec@gmail.com
+Checkpoint: edit on 23/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
-
-
-
+[1] Cheng H T, Koc L, Harmsen J, et al. Wide & Deep learning for recommender systems[C]
+    //Proceedings of the 1st Workshop on Deep Learning for Recommender Systems. 2016: 7-10.
+    (https://arxiv.org/abs/1606.07792)
 
 Wide & Deep blends a linear wide component (memorization of cross features) with a
 deep neural network (generalization) sharing the same feature space. The wide part
@@ -138,4 +137,4 @@ class WideDeep(BaseModel):
 
         # Combine wide and deep
         y = y_wide + y_deep
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
nextrec/models/ranking/xdeepfm.py
CHANGED
@@ -1,7 +1,7 @@
 """
 Date: create on 09/11/2025
-
-Yang Zhou,zyaztec@gmail.com
+Checkpoint: edit on 23/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
 Reference:
 [1] Lian J, Zhou X, Zhang F, et al. xdeepfm: Combining explicit and implicit feature interactions
     for recommender systems[C]//Proceedings of the 24th ACM SIGKDD international conference on
@@ -219,4 +219,4 @@ class xDeepFM(BaseModel):
 
         # Combine all parts
         y = y_linear + y_cin + y_deep
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
nextrec/utils/console.py
CHANGED
@@ -228,7 +228,7 @@ def group_metrics_by_task(
     metrics: Mapping[str, Any] | None,
     target_names: list[str] | str | None,
     default_task_name: str = "overall",
-)
+):
     if not metrics:
         return [], {}
 
nextrec/utils/data.py
CHANGED
@@ -4,7 +4,7 @@ Data utilities for NextRec.
 This module provides file I/O helpers and synthetic data generation.
 
 Date: create on 19/12/2025
-Checkpoint: edit on
+Checkpoint: edit on 24/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -19,6 +19,46 @@ import pyarrow.parquet as pq
 import torch
 import yaml
 
+FILE_FORMAT_CONFIG = {
+    "csv": {
+        "extension": [".csv", ".txt"],
+        "streaming": True,
+    },
+    "parquet": {
+        "extension": [".parquet"],
+        "streaming": True,
+    },
+    "feather": {
+        "extension": [".feather", ".ftr"],
+        "streaming": False,
+    },
+    "excel": {
+        "extension": [".xlsx", ".xls"],
+        "streaming": False,
+    },
+    "hdf5": {
+        "extension": [".h5", ".hdf5"],
+        "streaming": False,
+    },
+}
+
+
+def get_file_format_from_extension(ext: str) -> str | None:
+    """Get file format from extension."""
+    return {
+        ext.lstrip("."): fmt
+        for fmt, config in FILE_FORMAT_CONFIG.items()
+        for ext in config["extension"]
+    }.get(ext.lower().lstrip("."))
+
+
+def check_streaming_support(file_format: str) -> bool:
+    """Check if a format supports streaming."""
+    file_format = file_format.lower()
+    if file_format not in FILE_FORMAT_CONFIG:
+        return False
+    return FILE_FORMAT_CONFIG[file_format].get("streaming", False)
+
 
 def resolve_file_paths(path: str) -> tuple[list[str], str]:
     """
@@ -30,29 +70,39 @@ def resolve_file_paths(path: str) -> tuple[list[str], str]:
     path_obj = Path(path)
 
     if path_obj.is_file():
-
-
-
-
-
-
+        file_format = get_file_format_from_extension(path_obj.suffix)
+        if file_format is None:
+            raise ValueError(
+                f"Unsupported file extension: {path_obj.suffix}. "
+                f"Supported formats: {', '.join(FILE_FORMAT_CONFIG.keys())}"
+            )
+        return [str(path_obj)], file_format
 
     if path_obj.is_dir():
         collected_files = [p for p in path_obj.iterdir() if p.is_file()]
-
-
-
-
+        # Group files by format
+        format_groups: Dict[str, List[str]] = {}
+        for file in collected_files:
+            file_format = get_file_format_from_extension(file.suffix)
+            if file_format:
+                format_groups.setdefault(file_format, []).append(str(file))
+
+        if len(format_groups) > 1:
+            formats = ", ".join(format_groups.keys())
+            raise ValueError(
+                f"Directory contains mixed file formats: {formats}. "
+                "Please keep a single format per directory."
+            )
 
-        if
+        if not format_groups:
             raise ValueError(
-                "
+                f"No supported data files found in directory: {path}. "
+                f"Supported formats: {', '.join(FILE_FORMAT_CONFIG.keys())}"
             )
-
-
-
+
+        file_type = list(format_groups.keys())[0]
+        file_paths = format_groups[file_type]
         file_paths.sort()
-        file_type = "csv" if csv_files else "parquet"
         return file_paths, file_type
 
     raise ValueError(f"Invalid path: {path}")
@@ -60,15 +110,55 @@ def resolve_file_paths(path: str) -> tuple[list[str], str]:
 
 def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame:
     data_path = Path(path)
-
-
-
-
-
-
-
-
-
+
+    # Determine format
+    if data_format:
+        fmt = data_format.lower()
+    elif data_path.is_dir():
+        _, fmt = resolve_file_paths(str(data_path))
+    else:
+        fmt = get_file_format_from_extension(data_path.suffix)
+        if fmt is None:
+            raise ValueError(
+                f"Cannot determine format for {data_path}. "
+                f"Please specify data_format parameter."
+            )
+
+    if data_path.is_dir():
+        file_paths, _ = resolve_file_paths(str(data_path))
+        dataframes = [read_table(fp, fmt) for fp in file_paths]
+        if not dataframes:
+            raise ValueError(f"No supported data files found in directory: {data_path}")
+        if len(dataframes) == 1:
+            return dataframes[0]
+        return pd.concat(dataframes, ignore_index=True)
+
+    # Read based on format
+    try:
+        if fmt == "hdf5":
+            # HDF5 requires a key; use the first available key
+            with pd.HDFStore(data_path, mode="r") as store:
+                if len(store.keys()) == 0:
+                    raise ValueError(f"HDF5 file {data_path} contains no datasets")
+                return pd.read_hdf(data_path, key=store.keys()[0])
+        reader = {
+            "parquet": pd.read_parquet,
+            "csv": lambda p: pd.read_csv(p, low_memory=False),
+            "feather": pd.read_feather,
+            "excel": pd.read_excel,
+        }.get(fmt)
+        if reader:
+            return reader(data_path)
+        raise ValueError(
+            f"Unsupported format: {fmt}. "
+            f"Supported: {', '.join(FILE_FORMAT_CONFIG.keys())}"
+        )
+    except ImportError as e:
+        raise ImportError(
+            f"Format '{fmt}' requires additional dependencies. "
+            f"Install with: pip install pandas[{fmt}] or check documentation. "
+            f"Original error: {e}"
+        ) from e
 
 
 def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
@@ -78,12 +168,44 @@ def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]
 def iter_file_chunks(
     file_path: str, file_type: str, chunk_size: int
 ) -> Generator[pd.DataFrame, None, None]:
-
-
-
-
-
-
+    """Iterate over file in chunks for streaming reading.
+
+    Args:
+        file_path: Path to the file
+        file_type: Format type (csv, parquet)
+        chunk_size: Number of rows per chunk
+
+    Yields:
+        DataFrame chunks
+
+    Raises:
+        ValueError: If format doesn't support streaming
+    """
+    file_type = file_type.lower()
+    if not check_streaming_support(file_type):
+        raise ValueError(
+            f"Format '{file_type}' does not support streaming reads. "
+            "Formats with streaming support: csv, parquet"
+        )
+
+    try:
+        if file_type == "csv":
+            yield from pd.read_csv(file_path, chunksize=chunk_size)
+        elif file_type == "parquet":
+            parquet_file = pq.ParquetFile(file_path)
+            for batch in parquet_file.iter_batches(batch_size=chunk_size):
+                yield batch.to_pandas()
+        else:
+            raise ValueError(
+                f"Format '{file_type}' does not support streaming. "
+                f"Use read_table() to load the entire file into memory."
+            )
+    except ImportError as e:
+        raise ImportError(
+            f"Streaming format '{file_type}' requires additional dependencies. "
+            f"Install with: pip install pandas[{file_type}] pyarrow. "
+            f"Original error: {e}"
+        ) from e
 
 
 def default_output_dir(path: str) -> Path:
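A minimal usage sketch of the file-format helpers added above, assuming only the names and signatures shown in this diff; the file path and chunk size below are hypothetical examples, not part of the package:

```python
# Hypothetical usage of the 0.4.19 helpers in nextrec/utils/data.py.
# Function names and signatures come from the diff; the path and the
# chunk size are made-up examples.
from nextrec.utils.data import (
    check_streaming_support,
    get_file_format_from_extension,
    iter_file_chunks,
    read_table,
)

fmt = get_file_format_from_extension(".parquet")   # -> "parquet"
print(fmt, check_streaming_support(fmt))           # parquet True

# Small files: load everything at once; the format is inferred from the extension.
df = read_table("data/interactions.parquet")

# Large files: stream row chunks when the format supports it (csv, parquet).
for chunk in iter_file_chunks("data/interactions.parquet", fmt, chunk_size=10_000):
    print(len(chunk))
```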
nextrec/utils/model.py
CHANGED
@@ -2,6 +2,7 @@
 Model-related utilities for NextRec
 
 Date: create on 03/12/2025
+Checkpoint: edit on 24/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -9,6 +10,16 @@ from collections import OrderedDict
 
 import torch
 
+from nextrec.loss import (
+    ApproxNDCGLoss,
+    BPRLoss,
+    HingeLoss,
+    ListMLELoss,
+    ListNetLoss,
+    SampledSoftmaxLoss,
+    TripletLoss,
+)
+
 
 def merge_features(primary, secondary) -> list:
     merged: OrderedDict[str, object] = OrderedDict()
@@ -53,6 +64,80 @@ def compute_pair_scores(model, data, batch_size: int = 512):
     user_tensor = torch.as_tensor(user_emb, device=model.device)
     item_tensor = torch.as_tensor(item_emb, device=model.device)
     scores = model.compute_similarity(user_tensor, item_tensor)
-
+    mode = model.training_mode
+    if isinstance(mode, list):
+        mode = mode[0] if mode else "pointwise"
+    if mode == "pointwise":
         scores = torch.sigmoid(scores)
     return scores.detach().cpu().numpy()
+
+
+def prepare_ranking_targets(
+    y_pred: torch.Tensor, y_true: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if y_pred.dim() == 1:
+        y_pred = y_pred.view(-1, 1)
+    if y_true.dim() == 1:
+        y_true = y_true.view(-1, 1)
+    if y_pred.shape != y_true.shape:
+        raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+    return y_pred, y_true
+
+
+def split_pos_neg_scores(
+    scores: torch.Tensor, labels: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    if scores.dim() != 2 or labels.dim() != 2:
+        raise ValueError(
+            "[Ranking Error] pairwise/listwise training requires 2D scores and labels."
+        )
+    list_size = scores.size(1)
+    if list_size < 2:
+        raise ValueError(
+            "[Ranking Error] pairwise/listwise training requires list_size >= 2."
+        )
+    pos_mask = labels > 0
+    pos_counts = pos_mask.sum(dim=1)
+    neg_counts = list_size - pos_counts
+    if not torch.all(pos_counts == 1).item():
+        raise ValueError(
+            "[Ranking Error] pairwise/listwise with pos/neg split requires exactly one positive per row."
+        )
+    if not torch.all(neg_counts == list_size - 1).item():
+        raise ValueError(
+            "[Ranking Error] pairwise/listwise with pos/neg split requires at least one negative per row."
+        )
+    pos_scores = scores[pos_mask].view(-1)
+    neg_scores = scores[~pos_mask].view(scores.size(0), list_size - 1)
+    return pos_scores, neg_scores
+
+
+def compute_ranking_loss(
+    training_mode: str,
+    loss_fn: torch.nn.Module,
+    y_pred: torch.Tensor,
+    y_true: torch.Tensor,
+) -> torch.Tensor:
+    y_pred, y_true = prepare_ranking_targets(y_pred, y_true)
+    if training_mode == "pairwise":
+        pos_scores, neg_scores = split_pos_neg_scores(y_pred, y_true)
+        if isinstance(loss_fn, (BPRLoss, HingeLoss, SampledSoftmaxLoss)):
+            loss = loss_fn(pos_scores, neg_scores)
+        elif isinstance(loss_fn, TripletLoss):
+            raise ValueError(
+                "[Ranking Error] TripletLoss expects embeddings, not scalar scores."
+            )
+        else:
+            loss = loss_fn(pos_scores, neg_scores)
+    elif training_mode == "listwise":
+        if isinstance(loss_fn, (ListNetLoss, ListMLELoss, ApproxNDCGLoss)):
+            loss = loss_fn(y_pred, y_true)
+        elif isinstance(loss_fn, SampledSoftmaxLoss):
+            pos_scores, neg_scores = split_pos_neg_scores(y_pred, y_true)
+            loss = loss_fn(pos_scores, neg_scores)
+        else:
+            loss = loss_fn(y_pred, y_true)
+    else:
+        raise ValueError(f"[Ranking Error] Unknown training mode: {training_mode}")
+
+    return loss
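A minimal sketch of how the new `compute_ranking_loss` helper might be called, assuming the `BPRLoss` class exported by `nextrec.loss` (as imported in the diff) can be instantiated without arguments; the scores and labels below are synthetic and follow the one-positive-per-row layout enforced by `split_pos_neg_scores`:

```python
# Hypothetical call into the ranking-loss helper added in 0.4.19.
# compute_ranking_loss and BPRLoss are named in the diff; BPRLoss() with
# no arguments is an assumption, and the tensors below are synthetic.
import torch

from nextrec.loss import BPRLoss
from nextrec.utils.model import compute_ranking_loss

batch, list_size = 4, 5
y_pred = torch.randn(batch, list_size, requires_grad=True)  # model scores
y_true = torch.zeros(batch, list_size)
y_true[:, 0] = 1.0  # exactly one positive per row, the rest are negatives

loss = compute_ranking_loss("pairwise", BPRLoss(), y_pred, y_true)
loss.backward()
print(loss.item())
```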
{nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nextrec
-Version: 0.4.
+Version: 0.4.19
 Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
 Project-URL: Homepage, https://github.com/zerolovesea/NextRec
 Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,13 +63,12 @@ Description-Content-Type: text/markdown
 
 <div align="center">
 
-[](https://
+[](https://pypistats.org/packages/nextrec)
 
 
-
 
-
+[](https://deepwiki.com/zerolovesea/NextRec)
 
 Chinese Docs | [English Version](README_en.md)
 
@@ -244,11 +243,13 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
 nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 ```
 
-
+Prediction results are always saved to `{checkpoint_path}/predictions/{name}.{save_data_format}`.
+
+> As of version 0.4.19, the NextRec CLI supports single-machine training; distributed training features are still under development.
 
 ## Compatible Platforms
 
-The current latest version is 0.4.
+The current latest version is 0.4.19. All models and test code have been verified on the platforms below; if you run into a compatibility issue, please open an issue with the error report and your system version:
 
 | Platform | Configuration |
 |------|------|
{nextrec-0.4.17.dist-info → nextrec-0.4.19.dist-info}/RECORD
CHANGED
@@ -1,24 +1,24 @@
 nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
-nextrec/__version__.py,sha256=
-nextrec/cli.py,sha256=
+nextrec/__version__.py,sha256=mfToEOXd_f_mFG0ts1iWNTvKhsjPQqwRFaAmcx61xJo,23
+nextrec/cli.py,sha256=aJnv8A_K9xbzVRo0BPIUGg8vk__C4kniXweKanlh2G8,24326
 nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/basic/activation.py,sha256=uzTWfCOtBSkbu_Gk9XBNTj8__s241CaYLJk6l8nGX9I,2885
 nextrec/basic/callback.py,sha256=nn1f8FG9c52vJ-gvwteqPbk3-1QuNS1vmhBlkENdb0I,14636
 nextrec/basic/features.py,sha256=GyCUzGPuizUofrZSSOdqHK84YhnX4MGTdu7Cx2OGhUA,4654
-nextrec/basic/heads.py,sha256=
+nextrec/basic/heads.py,sha256=UcAfD6Tw3o-HVBHSqutvJSa56xX7ZIY_Fx2SNBCkb9E,3280
 nextrec/basic/layers.py,sha256=ZM3Nka3e2cit3e3peL0ukJCMgKZK1ovNFfAWvVOwlos,28556
-nextrec/basic/loggers.py,sha256=
+nextrec/basic/loggers.py,sha256=kZsvWnKl45bIcGid59M6m-khTxb_2P6gvWAg4wohaO4,6509
 nextrec/basic/metrics.py,sha256=1r6efTc9TpARNBt5X9ISoppTZflej6EdFkjPYHV-YZI,23162
-nextrec/basic/model.py,sha256=
+nextrec/basic/model.py,sha256=xueGxwl0A7XwdzgIqMuHavWon4glEsFKfu4e4QbGBWc,109348
 nextrec/basic/session.py,sha256=UOG_-EgCOxvqZwCkiEd8sgNV2G1sm_HbzKYVQw8yYDI,4483
 nextrec/data/__init__.py,sha256=YZQjpty1pDCM7q_YNmiA2sa5kbujUw26ObLHWjMPjKY,1194
 nextrec/data/batch_utils.py,sha256=0bYGVX7RlhnHv_ZBaUngjDIpBNw-igCk98DgOsF7T6o,2879
 nextrec/data/data_processing.py,sha256=lKXDBszrO5fJMAQetgSPr2mSQuzOluuz1eHV4jp0TDU,5538
 nextrec/data/data_utils.py,sha256=0Ls1cnG9lBz0ovtyedw5vwp7WegGK_iF-F8e_3DEddo,880
-nextrec/data/dataloader.py,sha256=
-nextrec/data/preprocessor.py,sha256=
+nextrec/data/dataloader.py,sha256=8kGZIrt88UnGFDOCNiMpCW5tn2H9D38GzfSaWHk-oqI,18943
+nextrec/data/preprocessor.py,sha256=sIkuGNOpDTjtGy6gIlxxro3ktcUWOjIEqeWd2FN9s2g,47253
 nextrec/loss/__init__.py,sha256=ZCgsfyR5YAecv6MdOsnUjkfacvZg2coQVjuKAfPvmRo,923
-nextrec/loss/grad_norm.py,sha256=
+nextrec/loss/grad_norm.py,sha256=1BU1uHh6CuNRc_M_bPP2mrVKOnUGQWv_tR_8-ETOJlg,8385
 nextrec/loss/listwise.py,sha256=UT9vJCOTOQLogVwaeTV7Z5uxIYnngGdxk-p9e97MGkU,5744
 nextrec/loss/loss_utils.py,sha256=xMmT_tWcKah_xcU3FzVMmSEzyZfxiMKZWUbwkAspcDg,4579
 nextrec/loss/pairwise.py,sha256=X9yg-8pcPt2IWU0AiUhWAt3_4W_3wIF0uSdDYTdoPFY,3398
@@ -26,28 +26,28 @@ nextrec/loss/pointwise.py,sha256=o9J3OznY0hlbDsUXqn3k-BBzYiuUH5dopz8QBFqS_kQ,734
 nextrec/models/generative/__init__.py,sha256=0MV3P-_ainPaTxmRBGWKUVCEt14KJvuvEHmRB3OQ1Fs,176
 nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/models/multi_task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nextrec/models/multi_task/esmm.py,sha256=
-nextrec/models/multi_task/mmoe.py,sha256=
-nextrec/models/multi_task/ple.py,sha256=
-nextrec/models/multi_task/poso.py,sha256=
-nextrec/models/multi_task/share_bottom.py,sha256=
+nextrec/models/multi_task/esmm.py,sha256=WoqRtXZnAQaWHkPMGrV1PCY-BOn6tcx_B8okung84Ec,6540
+nextrec/models/multi_task/mmoe.py,sha256=e6i1wG42LOrOWO602zEH7NJCXUiICcEK1xa_u5ZCQlI,8674
+nextrec/models/multi_task/ple.py,sha256=KjhYR4hT91Q8XGesOp1hHw-2933sIZmJLBkI5QlrZ0s,13101
+nextrec/models/multi_task/poso.py,sha256=cwLZeoFnC2Kiq7K9hxTuhVOlCTCN5RmK-7KKd9JIO_s,19241
+nextrec/models/multi_task/share_bottom.py,sha256=QJrKvhq21dhV41m4MtbhCUTLUBf1uUFhDA8d4dacu9k,6600
 nextrec/models/ranking/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nextrec/models/ranking/afm.py,sha256=
-nextrec/models/ranking/autoint.py,sha256=
-nextrec/models/ranking/dcn.py,sha256=
-nextrec/models/ranking/dcn_v2.py,sha256=
-nextrec/models/ranking/deepfm.py,sha256=
+nextrec/models/ranking/afm.py,sha256=lhHTyDh51yIEhEbNyxlRaJKSPleIrhcgeE2H6tvy5UA,10142
+nextrec/models/ranking/autoint.py,sha256=FbKfhXq31mUpv9pQYv7b34d_HPxDrzAiiIgcUbi-Qow,8126
+nextrec/models/ranking/dcn.py,sha256=u9Wvfjom7Aqy3TyyNJevJawrI7h-MB2DOgfj2j7keZE,7316
+nextrec/models/ranking/dcn_v2.py,sha256=sx0W6JaU3ZTR2DNMz2vaqVeRhwwzzCFxmyGWBkas5P0,11180
+nextrec/models/ranking/deepfm.py,sha256=RN9Kjgzl0ZZwlu3gHJQtyNs5-Qv7s2eNjsO1VZMf2hk,5217
 nextrec/models/ranking/dien.py,sha256=URKWeOPeRD5OIWNsAxgVvbetOSrBHoq2eO5rR5UJ0jU,18971
 nextrec/models/ranking/din.py,sha256=Y8v0gONRt1OZORmn0hqMuzMfkvX0Nz1gByJ94jo3MUw,9435
-nextrec/models/ranking/eulernet.py,sha256=
-nextrec/models/ranking/ffm.py,sha256=
+nextrec/models/ranking/eulernet.py,sha256=UPTTbuog9d4kjrOPSwBiLhiBUMlXvkJm75qNzkbYG3U,12218
+nextrec/models/ranking/ffm.py,sha256=8jgaCD2DLdxJ2qdm5DEOJ6ACKjopjIuOQvKYWWwgqOg,11279
 nextrec/models/ranking/fibinet.py,sha256=VP0gNoQwoLKxniv2HmHzxlnR3YlrnQJt6--CwmAgsW4,7932
-nextrec/models/ranking/fm.py,sha256=
-nextrec/models/ranking/lr.py,sha256=
-nextrec/models/ranking/masknet.py,sha256=
-nextrec/models/ranking/pnn.py,sha256=
-nextrec/models/ranking/widedeep.py,sha256=
-nextrec/models/ranking/xdeepfm.py,sha256=
+nextrec/models/ranking/fm.py,sha256=rNLF6Rs3AUc2H9elMsjQjHgvEsOe69Apb_T2j_CBmKg,4553
+nextrec/models/ranking/lr.py,sha256=rPqDyRrJNamWE44419fwCSX0ad7MNSy8_EUKCyTnHDg,4025
+nextrec/models/ranking/masknet.py,sha256=hvqe8FrU-IcPiXWFpvBzcRZc8Z8D6pMvpDola5uJGb0,12356
+nextrec/models/ranking/pnn.py,sha256=wtej_QcEifQk3WFNKUl5_YZAD5yTLQoPDwCEkf9G1pg,8219
+nextrec/models/ranking/widedeep.py,sha256=WAskeJbMA68_Qlxpl3vZc7KyrGVVyNQG_6LOH8VuxI4,5071
+nextrec/models/ranking/xdeepfm.py,sha256=_Ng4q18IGNuHf_M1RpoAk4ypm4zLNvx3tw9uBwpAnc8,8205
 nextrec/models/representation/__init__.py,sha256=O3QHMMXBszwM-mTl7bA3wawNZvDGet-QIv6Ys5GHGJ8,190
 nextrec/models/representation/autorec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/models/representation/bpr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -66,14 +66,14 @@ nextrec/models/sequential/hstu.py,sha256=P2Kl7HEL3afwiCApGKQ6UbUNO9eNXXrB10H7iiF
 nextrec/models/sequential/sasrec.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/utils/__init__.py,sha256=C-1l-suSsN_MlPlj_5LApyCRQLOao5l7bO0SccwKHw4,2598
 nextrec/utils/config.py,sha256=VgCh5fto8HGodwXPJacenqjxre3Aw6tw-mntW9n3OYA,20044
-nextrec/utils/console.py,sha256=
-nextrec/utils/data.py,sha256=
+nextrec/utils/console.py,sha256=jlZLCOopPHcspOx7ymWNcz3i77EgzP6v8bqZaiiBbuM,13597
+nextrec/utils/data.py,sha256=pSL96mWjWfW_RKE-qlUSs9vfiYnFZAaRirzA6r7DB6s,24994
 nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
 nextrec/utils/feature.py,sha256=rsUAv3ELyDpehVw8nPEEsLCCIjuKGTJJZuFaWB_wrPk,633
-nextrec/utils/model.py,sha256=
+nextrec/utils/model.py,sha256=fHvFciUuMOVcM1oWiRva4LcArRdZ1R5Uzml-COSqqvM,4688
 nextrec/utils/torch_utils.py,sha256=AKfYbSOJjEw874xsDB5IO3Ote4X7vnqzt_E0jJny0o8,13468
-nextrec-0.4.
-nextrec-0.4.
-nextrec-0.4.
-nextrec-0.4.
-nextrec-0.4.
+nextrec-0.4.19.dist-info/METADATA,sha256=alzuH4O4ACST72cQX6SEbwc8wWGODXX0BV87b2lPUF8,21482
+nextrec-0.4.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+nextrec-0.4.19.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
+nextrec-0.4.19.dist-info/licenses/LICENSE,sha256=2fQfVKeafywkni7MYHyClC6RGGC3laLTXCNBx-ubtp0,1064
+nextrec-0.4.19.dist-info/RECORD,,
|