nextrec 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +220 -106
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1082 -400
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +272 -95
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +53 -37
- nextrec/models/multi_task/mmoe.py +64 -45
- nextrec/models/multi_task/ple.py +101 -48
- nextrec/models/multi_task/poso.py +113 -36
- nextrec/models/multi_task/share_bottom.py +48 -35
- nextrec/models/ranking/afm.py +72 -37
- nextrec/models/ranking/autoint.py +72 -55
- nextrec/models/ranking/dcn.py +55 -35
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +32 -22
- nextrec/models/ranking/dien.py +155 -99
- nextrec/models/ranking/din.py +85 -57
- nextrec/models/ranking/fibinet.py +52 -32
- nextrec/models/ranking/fm.py +29 -23
- nextrec/models/ranking/masknet.py +91 -29
- nextrec/models/ranking/pnn.py +31 -28
- nextrec/models/ranking/widedeep.py +34 -26
- nextrec/models/ranking/xdeepfm.py +60 -38
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/METADATA +4 -4
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py
CHANGED
@@ -7,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
 import os
 import tempfile
 from dataclasses import dataclass
-from datetime import datetime
+from datetime import datetime
 from pathlib import Path

 __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
     "create_session",
 ]

+
 @dataclass(frozen=True)
 class Session:
     """Encapsulate standard folders for a NextRec experiment."""
@@ -35,7 +36,7 @@ class Session:
     @property
     def predictions_dir(self) -> Path:
         return self._ensure_dir(self.root / "predictions")
-
+
     @property
     def processor_dir(self) -> Path:
         return self._ensure_dir(self.root / "processor")
@@ -60,6 +61,7 @@ class Session:
         path.mkdir(parents=True, exist_ok=True)
         return path

+
 def create_session(experiment_id: str | Path | None = None) -> Session:

     if experiment_id is not None and str(experiment_id).strip():
@@ -86,6 +88,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:

     return Session(experiment_id=exp_id, root=root, log_basename=log_basename)

+
 def resolve_save_path(
     path: str | os.PathLike | Path | None,
     default_dir: str | Path,
@@ -129,7 +132,11 @@ def resolve_save_path(
             base_dir = candidate
             file_stem = default_name
         else:
-            base_dir =
+            base_dir = (
+                candidate.parent
+                if candidate.parent not in (Path("."), Path(""))
+                else base_dir
+            )
             file_stem = candidate.name or default_name
     else:
         file_stem = default_name
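Note on the resolve_save_path change above: the new multi-line expression keeps an explicitly supplied parent directory and otherwise falls back to the default directory. A minimal standalone sketch of that selection rule (pick_base_dir is a hypothetical helper mirroring the added expression, not part of the package):

from pathlib import Path

def pick_base_dir(candidate: Path, default_dir: Path) -> Path:
    # Mirrors the added expression: keep an explicit parent directory,
    # otherwise fall back to the default directory.
    return candidate.parent if candidate.parent not in (Path("."), Path("")) else default_dir

pick_base_dir(Path("model.pt"), Path("nextrec_logs"))       # -> Path("nextrec_logs")
pick_base_dir(Path("ckpt/model.pt"), Path("nextrec_logs"))  # -> Path("ckpt")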
nextrec/cli.py
ADDED
@@ -0,0 +1,498 @@
+"""
+Command-line interface for NextRec training and prediction.
+
+
+NextRec supports a flexible training and prediction pipeline driven by configuration files.
+After preparing the configuration YAML files for training and prediction, users can run the
+following script to execute the desired operations.
+
+Examples:
+    # Train a model
+    nextrec --mode=train --train_config=tutorials/iflytek/scripts/masknet/train_config.yaml
+
+    # Run prediction
+    nextrec --mode=predict --predict_config=tutorials/iflytek/scripts/masknet/predict_config.yaml
+
+Date: create on 06/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
+"""
+
+import argparse
+import logging
+import pickle
+import time
+from pathlib import Path
+from typing import Any, Dict, List
+
+import pandas as pd
+
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.data.data_utils import split_dict_random
+from nextrec.data.dataloader import RecDataLoader
+from nextrec.data.preprocessor import DataProcessor
+from nextrec.utils.config import (
+    build_feature_objects,
+    build_model_instance,
+    extract_feature_groups,
+    register_processor_features,
+    resolve_path,
+    select_features,
+)
+from nextrec.utils.feature import normalize_to_list
+from nextrec.utils.file import (
+    iter_file_chunks,
+    read_table,
+    read_yaml,
+    resolve_file_paths,
+)
+from nextrec.basic.loggers import setup_logger
+
+logger = logging.getLogger(__name__)
+
+
+def train_model(train_config_path: str) -> None:
+    """
+    Train a NextRec model using the provided configuration file.
+
+    configuration file must specify the below sections:
+    - session: Session settings including id and artifact root
+    - data: Data settings including path, format, target, validation split
+    - dataloader: DataLoader settings including batch sizes and shuffling
+    - model_config: Path to the model configuration YAML file
+    - feature_config: Path to the feature configuration YAML file
+    - train: Training settings including optimizer, loss, metrics, epochs, etc.
+    """
+    config_file = Path(train_config_path)
+    config_dir = config_file.resolve().parent
+    cfg = read_yaml(config_file)
+
+    # read session configuration
+    session_cfg = cfg.get("session", {}) or {}
+    session_id = session_cfg.get("id", "nextrec_cli_session")
+    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
+    session_dir = artifact_root / session_id
+    setup_logger(session_id=session_id)
+
+    processor_path = session_dir / "processor.pkl"
+    processor_path = Path(processor_path)
+    processor_path.parent.mkdir(parents=True, exist_ok=True)
+
+    data_cfg = cfg.get("data", {}) or {}
+    dataloader_cfg = cfg.get("dataloader", {}) or {}
+    streaming = bool(data_cfg.get("streaming", False))
+    dataloader_chunk_size = dataloader_cfg.get("chunk_size", 20000)
+
+    # train data
+    data_path = resolve_path(data_cfg["path"], config_dir)
+    target = normalize_to_list(data_cfg["target"])
+    file_paths: List[str] = []
+    file_type: str | None = None
+    streaming_train_files: List[str] | None = None
+    streaming_valid_files: List[str] | None = None
+
+    feature_cfg_path = resolve_path(
+        cfg.get("feature_config", "feature_config.yaml"), config_dir
+    )
+    model_cfg_path = resolve_path(
+        cfg.get("model_config", "model_config.yaml"), config_dir
+    )
+
+    feature_cfg = read_yaml(feature_cfg_path)
+    model_cfg = read_yaml(model_cfg_path)
+
+    if streaming:
+        file_paths, file_type = resolve_file_paths(str(data_path))
+        first_file = file_paths[0]
+        first_chunk_size = max(1, min(dataloader_chunk_size, 1000))
+        chunk_iter = iter_file_chunks(first_file, file_type, first_chunk_size)
+        try:
+            first_chunk = next(chunk_iter)
+        except StopIteration as exc:
+            raise ValueError(f"Data file is empty: {first_file}") from exc
+        df_columns = list(first_chunk.columns)
+
+    else:
+        df = read_table(data_path, data_cfg.get("format"))
+        df_columns = list(df.columns)
+
+    # for some models have independent feature groups, we need to extract them here
+    feature_groups, grouped_columns = extract_feature_groups(feature_cfg, df_columns)
+    if feature_groups:
+        model_cfg.setdefault("params", {})
+        model_cfg["params"].setdefault("feature_groups", feature_groups)
+
+    dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
+    used_columns = (
+        dense_names + sparse_names + sequence_names + grouped_columns + target
+    )
+
+    # keep order but drop duplicates
+    seen = set()
+    unique_used_columns = []
+    for col in used_columns:
+        if col not in seen:
+            unique_used_columns.append(col)
+            seen.add(col)
+
+    processor = DataProcessor()
+    register_processor_features(
+        processor, feature_cfg, dense_names, sparse_names, sequence_names
+    )
+
+    if streaming:
+        processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
+        processed = None
+        df = None  # type: ignore[assignment]
+    else:
+        df = df[unique_used_columns]
+        processor.fit(df)
+        processed = processor.transform(df, return_dict=True)
+
+    processor.save(processor_path)
+    dense_features, sparse_features, sequence_features = build_feature_objects(
+        processor,
+        feature_cfg,
+        dense_names,
+        sparse_names,
+        sequence_names,
+    )
+
+    # Check if validation dataset path is specified
+    val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
+    if streaming:
+        if not file_paths:
+            file_paths, file_type = resolve_file_paths(str(data_path))
+        streaming_train_files = file_paths
+        streaming_valid_ratio = data_cfg.get("valid_ratio")
+        if val_data_path:
+            streaming_valid_files = None
+        elif streaming_valid_ratio is not None:
+            ratio = float(streaming_valid_ratio)
+            if not (0 < ratio < 1):
+                raise ValueError(
+                    f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
+                )
+            total_files = len(file_paths)
+            if total_files < 2:
+                raise ValueError(
+                    "[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
+                )
+            val_count = max(1, int(round(total_files * ratio)))
+            if val_count >= total_files:
+                val_count = total_files - 1
+            streaming_valid_files = file_paths[-val_count:]
+            streaming_train_files = file_paths[:-val_count]
+            logger.info(
+                "使用 valid_ratio=%.3f 切分文件: 训练 %d 个文件, 验证 %d 个文件",
+                ratio,
+                len(streaming_train_files),
+                len(streaming_valid_files),
+            )
+    train_data: Dict[str, Any]
+    valid_data: Dict[str, Any] | None
+
+    if val_data_path and not streaming:
+        # Use specified validation dataset path
+        logger.info("使用指定的验证集路径: %s", val_data_path)
+        val_data_resolved = resolve_path(val_data_path, config_dir)
+        val_df = read_table(val_data_resolved, data_cfg.get("format"))
+        val_df = val_df[unique_used_columns]
+        if not isinstance(processed, dict):
+            raise TypeError("Processed data must be a dictionary")
+        train_data = processed
+        valid_data_result = processor.transform(val_df, return_dict=True)
+        if not isinstance(valid_data_result, dict):
+            raise TypeError("Validation data must be a dictionary")
+        valid_data = valid_data_result
+        train_size = len(list(train_data.values())[0])
+        valid_size = len(list(valid_data.values())[0])
+        logger.info("训练集样本数: %s, 验证集样本数: %s", train_size, valid_size)
+    elif streaming:
+        train_data = None  # type: ignore[assignment]
+        valid_data = None
+        if not val_data_path and not streaming_valid_files:
+            logger.info(
+                "流式训练模式,未指定验证集路径且未配置 valid_ratio,跳过验证集创建"
+            )
+    else:
+        # Split data using valid_ratio
+        logger.info("使用 valid_ratio 切分数据: %s", data_cfg.get("valid_ratio", 0.2))
+        if not isinstance(processed, dict):
+            raise TypeError("Processed data must be a dictionary for splitting")
+        train_data, valid_data = split_dict_random(
+            processed,
+            test_size=data_cfg.get("valid_ratio", 0.2),
+            random_state=data_cfg.get("random_state", 2024),
+        )
+
+    dataloader = RecDataLoader(
+        dense_features=dense_features,
+        sparse_features=sparse_features,
+        sequence_features=sequence_features,
+        target=target,
+        processor=processor if streaming else None,
+    )
+    if streaming:
+        train_stream_source = streaming_train_files or file_paths
+        train_loader = dataloader.create_dataloader(
+            data=train_stream_source,
+            batch_size=dataloader_cfg.get("train_batch_size", 512),
+            shuffle=dataloader_cfg.get("train_shuffle", True),
+            load_full=False,
+            chunk_size=dataloader_chunk_size,
+        )
+        valid_loader = None
+        if val_data_path:
+            val_data_resolved = resolve_path(val_data_path, config_dir)
+            valid_loader = dataloader.create_dataloader(
+                data=str(val_data_resolved),
+                batch_size=dataloader_cfg.get("valid_batch_size", 512),
+                shuffle=dataloader_cfg.get("valid_shuffle", False),
+                load_full=False,
+                chunk_size=dataloader_chunk_size,
+            )
+        elif streaming_valid_files:
+            valid_loader = dataloader.create_dataloader(
+                data=streaming_valid_files,
+                batch_size=dataloader_cfg.get("valid_batch_size", 512),
+                shuffle=dataloader_cfg.get("valid_shuffle", False),
+                load_full=False,
+                chunk_size=dataloader_chunk_size,
+            )
+    else:
+        train_loader = dataloader.create_dataloader(
+            data=train_data,
+            batch_size=dataloader_cfg.get("train_batch_size", 512),
+            shuffle=dataloader_cfg.get("train_shuffle", True),
+        )
+        valid_loader = dataloader.create_dataloader(
+            data=valid_data,
+            batch_size=dataloader_cfg.get("valid_batch_size", 512),
+            shuffle=dataloader_cfg.get("valid_shuffle", False),
+        )
+
+    model_cfg.setdefault("session_id", session_id)
+    train_cfg = cfg.get("train", {}) or {}
+    device = train_cfg.get("device", model_cfg.get("device", "cpu"))
+    model = build_model_instance(
+        model_cfg,
+        model_cfg_path,
+        dense_features,
+        sparse_features,
+        sequence_features,
+        target,
+        device,
+    )
+
+    model.compile(
+        optimizer=train_cfg.get("optimizer", "adam"),
+        optimizer_params=train_cfg.get("optimizer_params", {}),
+        loss=train_cfg.get("loss", "focal"),
+        loss_params=train_cfg.get("loss_params", {}),
+    )
+
+    model.fit(
+        train_data=train_loader,
+        valid_data=valid_loader,
+        metrics=train_cfg.get("metrics", ["auc", "recall", "precision"]),
+        epochs=train_cfg.get("epochs", 1),
+        batch_size=train_cfg.get(
+            "batch_size", dataloader_cfg.get("train_batch_size", 512)
+        ),
+        shuffle=train_cfg.get("shuffle", True),
+    )
+
+
+def predict_model(predict_config_path: str) -> None:
+    """
+    Run prediction using a trained model and configuration file.
+    """
+    config_file = Path(predict_config_path)
+    config_dir = config_file.resolve().parent
+    cfg = read_yaml(config_file)
+
+    session_cfg = cfg.get("session", {}) or {}
+    session_id = session_cfg.get("id", "masknet_tutorial")
+    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
+    session_dir = Path(cfg.get("checkpoint_path") or (artifact_root / session_id))
+    setup_logger(session_id=session_id)
+
+    processor_path = Path(session_dir / "processor.pkl")
+    if not processor_path.exists():
+        processor_path = session_dir / "processor" / "processor.pkl"
+
+    predict_cfg = cfg.get("predict", {}) or {}
+    model_cfg_path = resolve_path(
+        cfg.get("model_config", "model_config.yaml"), config_dir
+    )
+    feature_cfg_path = resolve_path(
+        cfg.get("feature_config", "feature_config.yaml"), config_dir
+    )
+
+    model_cfg = read_yaml(model_cfg_path)
+    feature_cfg = read_yaml(feature_cfg_path)
+    model_cfg.setdefault("session_id", session_id)
+    feature_groups_raw = feature_cfg.get("feature_groups") or {}
+    model_cfg.setdefault("params", {})
+
+    # attach feature_groups in predict phase to avoid missing bindings
+    model_cfg["params"]["feature_groups"] = feature_groups_raw
+
+    processor = DataProcessor.load(processor_path)
+
+    # Load checkpoint and ensure required parameters are passed
+    checkpoint_base = Path(session_dir)
+    if checkpoint_base.is_dir():
+        candidates = sorted(checkpoint_base.glob("*.model"))
+        if not candidates:
+            raise FileNotFoundError(
+                f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
+            )
+        model_file = candidates[-1]
+        config_dir_for_features = checkpoint_base
+    else:
+        model_file = (
+            checkpoint_base.with_suffix(".model")
+            if checkpoint_base.suffix == ""
+            else checkpoint_base
+        )
+        config_dir_for_features = model_file.parent
+
+    features_config_path = config_dir_for_features / "features_config.pkl"
+    if not features_config_path.exists():
+        raise FileNotFoundError(
+            f"[NextRec CLI Error]: Unable to find features_config.pkl: {features_config_path}"
+        )
+    with open(features_config_path, "rb") as f:
+        features_config = pickle.load(f)
+
+    all_features = features_config.get("all_features", [])
+    target_cols = features_config.get("target", [])
+    id_columns = features_config.get("id_columns", [])
+
+    dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+    sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+    sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
+
+    target_override = (
+        cfg.get("targets")
+        or model_cfg.get("targets")
+        or model_cfg.get("params", {}).get("targets")
+        or model_cfg.get("params", {}).get("target")
+    )
+    if target_override:
+        target_cols = normalize_to_list(target_override)
+
+    # Recompute feature_groups with available feature names to drive bindings
+    feature_group_names = [f.name for f in all_features if hasattr(f, "name")]
+    parsed_feature_groups, _ = extract_feature_groups(feature_cfg, feature_group_names)
+    if parsed_feature_groups:
+        model_cfg.setdefault("params", {})
+        model_cfg["params"]["feature_groups"] = parsed_feature_groups
+
+    model = build_model_instance(
+        model_cfg=model_cfg,
+        model_cfg_path=model_cfg_path,
+        dense_features=dense_features,
+        sparse_features=sparse_features,
+        sequence_features=sequence_features,
+        target=target_cols,
+        device=predict_cfg.get("device", "cpu"),
+    )
+    model.id_columns = id_columns
+    model.load_model(
+        model_file, map_location=predict_cfg.get("device", "cpu"), verbose=True
+    )
+
+    id_columns = []
+    if predict_cfg.get("id_column"):
+        id_columns = [predict_cfg["id_column"]]
+        model.id_columns = id_columns
+
+    rec_dataloader = RecDataLoader(
+        dense_features=model.dense_features,
+        sparse_features=model.sparse_features,
+        sequence_features=model.sequence_features,
+        target=None,
+        id_columns=id_columns or model.id_columns,
+        processor=processor,
+    )
+
+    data_path = resolve_path(predict_cfg["data_path"], config_dir)
+    batch_size = predict_cfg.get("batch_size", 512)
+
+    pred_loader = rec_dataloader.create_dataloader(
+        data=str(data_path),
+        batch_size=batch_size,
+        shuffle=False,
+        load_full=predict_cfg.get("load_full", False),
+        chunk_size=predict_cfg.get("chunk_size", 20000),
+    )
+
+    output_path = resolve_path(predict_cfg["output_path"], config_dir)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    start = time.time()
+    model.predict(
+        data=pred_loader,
+        batch_size=batch_size,
+        include_ids=bool(id_columns),
+        return_dataframe=False,
+        save_path=output_path,
+        save_format=predict_cfg.get("save_format", "csv"),
+    )
+    duration = time.time() - start
+    logger.info(f"Prediction completed, results saved to: {output_path}")
+    logger.info(f"Total time: {duration:.2f} seconds")
+
+    preview_rows = predict_cfg.get("preview_rows", 0)
+    if preview_rows > 0:
+        try:
+            preview = pd.read_csv(output_path, nrows=preview_rows)
+            logger.info(f"Output preview:\n{preview}")
+        except Exception as exc:  # pragma: no cover
+            logger.warning(f"Failed to read output preview: {exc}")
+
+
+def main() -> None:
+    """Parse CLI arguments and dispatch to train or predict mode."""
+    parser = argparse.ArgumentParser(
+        description="NextRec: Training and Prediction Pipeline",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+    # Train a model
+    nextrec --mode=train --train_config=configs/train_config.yaml
+
+    # Run prediction
+    nextrec --mode=predict --predict_config=configs/predict_config.yaml
+        """,
+    )
+    parser.add_argument(
+        "--mode",
+        choices=["train", "predict"],
+        required=True,
+        help="运行模式:train 或 predict",
+    )
+    parser.add_argument("--train_config", help="训练配置文件路径")
+    parser.add_argument("--predict_config", help="预测配置文件路径")
+    parser.add_argument(
+        "--config",
+        help="通用配置文件路径(已废弃,建议使用 --train_config 或 --predict_config)",
+    )
+    args = parser.parse_args()
+
+    if args.mode == "train":
+        config_path = args.train_config or args.config
+        if not config_path:
+            parser.error("train 模式需要提供 --train_config")
+        train_model(config_path)
+    else:
+        config_path = args.predict_config or args.config
+        if not config_path:
+            parser.error("predict 模式需要提供 --predict_config")
+        predict_model(config_path)
+
+
+if __name__ == "__main__":
+    main()
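For reference, train_model above reads the training YAML into a plain mapping and consumes the sections listed in its docstring. A sketch of that structure as the dict read_yaml would return — the data path, target name, and config file names are illustrative placeholders; the remaining values mirror the defaults visible in the code:

train_config = {
    "session": {"id": "nextrec_cli_session", "artifact_root": "nextrec_logs"},
    "data": {
        "path": "data/train.csv",   # placeholder; resolved relative to the config file
        "target": "label",          # placeholder; normalize_to_list also accepts a list
        "valid_ratio": 0.2,
        "streaming": False,
    },
    "dataloader": {"train_batch_size": 512, "valid_batch_size": 512, "chunk_size": 20000},
    "feature_config": "feature_config.yaml",
    "model_config": "model_config.yaml",
    "train": {
        "optimizer": "adam",
        "loss": "focal",
        "metrics": ["auc", "recall", "precision"],
        "epochs": 1,
        "device": "cpu",
    },
}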
nextrec/data/__init__.py
CHANGED
@@ -27,35 +27,29 @@ from nextrec.data import data_utils

 __all__ = [
     # Batch utilities
-
-
-
-
+    "collate_fn",
+    "batch_to_dict",
+    "stack_section",
     # Data processing
-
-
-
-
-
+    "get_column_data",
+    "split_dict_random",
+    "build_eval_candidates",
+    "get_user_ids",
     # File utilities
-
-
-
-
-
-
+    "resolve_file_paths",
+    "iter_file_chunks",
+    "read_table",
+    "load_dataframes",
+    "default_output_dir",
     # DataLoader
-
-
-
-
-
+    "TensorDictDataset",
+    "FileDataset",
+    "RecDataLoader",
+    "build_tensors_from_data",
     # Preprocessor
-
-
+    "DataProcessor",
     # Features
-
-
+    "FeatureSet",
     # Legacy module
-
+    "data_utils",
 ]
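With the rewritten __all__, these helpers are importable directly from the package namespace, assuming nextrec/data/__init__.py re-exports the listed names, e.g.:

from nextrec.data import DataProcessor, RecDataLoader, collate_fn, split_dict_random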
nextrec/data/batch_utils.py
CHANGED
@@ -9,16 +9,22 @@ import torch
 import numpy as np
 from typing import Any, Mapping

+
 def stack_section(batch: list[dict], section: str):
     entries = [item.get(section) for item in batch if item.get(section) is not None]
     if not entries:
         return None
     merged: dict = {}
     for name in entries[0]:  # type: ignore
-        tensors = [
+        tensors = [
+            item[section][name]
+            for item in batch
+            if item.get(section) is not None and name in item[section]
+        ]
         merged[name] = torch.stack(tensors, dim=0)
     return merged

+
 def collate_fn(batch):
     """
     Collate a list of sample dicts into the unified batch format:
@@ -28,7 +34,7 @@ def collate_fn(batch):
         "ids": {id_name: Tensor(B, ...)} or None,
     }
     Args: batch: List of samples from DataLoader
-
+
     Returns: dict: Batched data in unified format
     """
     if not batch:
@@ -72,7 +78,9 @@

 def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
     if not (isinstance(batch_data, Mapping) and "features" in batch_data):
-        raise TypeError(
+        raise TypeError(
+            "[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader."
+        )
     return {
         "features": batch_data.get("features", {}),
         "labels": batch_data.get("labels"),