nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py
CHANGED
|
@@ -7,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
|
|
|
7
7
|
import os
|
|
8
8
|
import tempfile
|
|
9
9
|
from dataclasses import dataclass
|
|
10
|
-
from datetime import datetime
|
|
10
|
+
from datetime import datetime
|
|
11
11
|
from pathlib import Path
|
|
12
12
|
|
|
13
13
|
__all__ = [
|
|
@@ -16,6 +16,7 @@ __all__ = [
|
|
|
16
16
|
"create_session",
|
|
17
17
|
]
|
|
18
18
|
|
|
19
|
+
|
|
19
20
|
@dataclass(frozen=True)
|
|
20
21
|
class Session:
|
|
21
22
|
"""Encapsulate standard folders for a NextRec experiment."""
|
|
@@ -35,7 +36,7 @@ class Session:
|
|
|
35
36
|
@property
|
|
36
37
|
def predictions_dir(self) -> Path:
|
|
37
38
|
return self._ensure_dir(self.root / "predictions")
|
|
38
|
-
|
|
39
|
+
|
|
39
40
|
@property
|
|
40
41
|
def processor_dir(self) -> Path:
|
|
41
42
|
return self._ensure_dir(self.root / "processor")
|
|
@@ -60,6 +61,7 @@ class Session:
|
|
|
60
61
|
path.mkdir(parents=True, exist_ok=True)
|
|
61
62
|
return path
|
|
62
63
|
|
|
64
|
+
|
|
63
65
|
def create_session(experiment_id: str | Path | None = None) -> Session:
|
|
64
66
|
|
|
65
67
|
if experiment_id is not None and str(experiment_id).strip():
|
|
@@ -86,6 +88,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
|
|
|
86
88
|
|
|
87
89
|
return Session(experiment_id=exp_id, root=root, log_basename=log_basename)
|
|
88
90
|
|
|
91
|
+
|
|
89
92
|
def resolve_save_path(
|
|
90
93
|
path: str | os.PathLike | Path | None,
|
|
91
94
|
default_dir: str | Path,
|
|
@@ -129,7 +132,11 @@ def resolve_save_path(
|
|
|
129
132
|
base_dir = candidate
|
|
130
133
|
file_stem = default_name
|
|
131
134
|
else:
|
|
132
|
-
base_dir =
|
|
135
|
+
base_dir = (
|
|
136
|
+
candidate.parent
|
|
137
|
+
if candidate.parent not in (Path("."), Path(""))
|
|
138
|
+
else base_dir
|
|
139
|
+
)
|
|
133
140
|
file_stem = candidate.name or default_name
|
|
134
141
|
else:
|
|
135
142
|
file_stem = default_name
|
nextrec/cli.py
ADDED
|
@@ -0,0 +1,492 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command-line interface for NextRec training and prediction.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
NextRec supports a flexible training and prediction pipeline driven by configuration files.
|
|
6
|
+
After preparing the configuration YAML files for training and prediction, users can run the
|
|
7
|
+
following script to execute the desired operations.
|
|
8
|
+
|
|
9
|
+
Examples:
|
|
10
|
+
# Train a model
|
|
11
|
+
nextrec --mode=train --train_config=nextrec_cli_preset/train_config.yaml
|
|
12
|
+
|
|
13
|
+
# Run prediction
|
|
14
|
+
nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
|
|
15
|
+
|
|
16
|
+
Date: create on 06/12/2025
|
|
17
|
+
Author: Yang Zhou, zyaztec@gmail.com
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
import argparse
|
|
21
|
+
import logging
|
|
22
|
+
import pickle
|
|
23
|
+
import time
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Any, Dict, List
|
|
26
|
+
|
|
27
|
+
import pandas as pd
|
|
28
|
+
|
|
29
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
30
|
+
from nextrec.data.data_utils import split_dict_random
|
|
31
|
+
from nextrec.data.dataloader import RecDataLoader
|
|
32
|
+
from nextrec.data.preprocessor import DataProcessor
|
|
33
|
+
from nextrec.utils.config import (
|
|
34
|
+
build_feature_objects,
|
|
35
|
+
build_model_instance,
|
|
36
|
+
register_processor_features,
|
|
37
|
+
resolve_path,
|
|
38
|
+
select_features,
|
|
39
|
+
)
|
|
40
|
+
from nextrec.utils.feature import normalize_to_list
|
|
41
|
+
from nextrec.utils.file import (
|
|
42
|
+
iter_file_chunks,
|
|
43
|
+
read_table,
|
|
44
|
+
read_yaml,
|
|
45
|
+
resolve_file_paths,
|
|
46
|
+
)
|
|
47
|
+
from nextrec.basic.loggers import setup_logger
|
|
48
|
+
|
|
49
|
+
logger = logging.getLogger(__name__)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _split_streaming_files(
    file_paths: List[str], valid_ratio: Any
) -> tuple[List[str], List[str]]:
    """Split an ordered list of data files into ``(train_files, valid_files)``.

    The trailing ``round(len(file_paths) * valid_ratio)`` files are held out for
    validation, clamped so that each split receives at least one file.

    Raises:
        ValueError: if ``valid_ratio`` is not strictly between 0 and 1, or if
            fewer than two files are available to split.
    """
    ratio = float(valid_ratio)
    if not (0 < ratio < 1):
        raise ValueError(
            f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {valid_ratio}"
        )
    total_files = len(file_paths)
    if total_files < 2:
        raise ValueError(
            "[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
        )
    # At least one validation file, and never all of them.
    val_count = min(max(1, int(round(total_files * ratio))), total_files - 1)
    return file_paths[:-val_count], file_paths[-val_count:]


def train_model(train_config_path: str) -> None:
    """
    Train a NextRec model using the provided configuration file.

    The configuration file must specify the following sections:
    - session: Session settings including id and artifact root
    - data: Data settings including path, format, target, validation split
    - dataloader: DataLoader settings including batch sizes and shuffling
    - model_config: Path to the model configuration YAML file
    - feature_config: Path to the feature configuration YAML file
    - train: Training settings including optimizer, loss, metrics, epochs, etc.

    Data, feature/model config and validation paths are resolved relative to
    the directory containing ``train_config_path``. ``session.artifact_root`` is
    used as given, so a relative value is relative to the current working
    directory. The fitted processor is saved to
    ``<artifact_root>/<session id>/processor.pkl``.

    Validation data comes from, in order of precedence:
    - ``data.val_path`` / ``data.valid_path``: a separately stored dataset;
    - streaming mode: the trailing files selected by ``data.valid_ratio``
      (no validation at all when ``valid_ratio`` is unset);
    - in-memory mode: a random split of ``data.valid_ratio`` (default 0.2).

    Raises:
        ValueError: if the streaming data path resolves to no files, the first
            file is empty, or the streaming ``valid_ratio`` split is invalid.
        TypeError: if the processor does not return dict-shaped data.
    """
    config_file = Path(train_config_path)
    config_dir = config_file.resolve().parent
    cfg = read_yaml(config_file)

    # Session: id + artifact root decide where the fitted processor is stored.
    session_cfg = cfg.get("session", {}) or {}
    session_id = session_cfg.get("id", "nextrec_cli_session")
    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
    session_dir = artifact_root / session_id
    setup_logger(session_id=session_id)

    processor_path = session_dir / "processor.pkl"
    processor_path.parent.mkdir(parents=True, exist_ok=True)

    data_cfg = cfg.get("data", {}) or {}
    dataloader_cfg = cfg.get("dataloader", {}) or {}
    streaming = bool(data_cfg.get("streaming", False))
    dataloader_chunk_size = dataloader_cfg.get("chunk_size", 20000)
    num_workers = dataloader_cfg.get("num_workers", 0)
    train_batch_size = dataloader_cfg.get("train_batch_size", 512)
    valid_batch_size = dataloader_cfg.get("valid_batch_size", 512)
    valid_shuffle = dataloader_cfg.get("valid_shuffle", False)

    # Training data location and targets.
    data_path = resolve_path(data_cfg["path"], config_dir)
    target = normalize_to_list(data_cfg["target"])

    feature_cfg_path = resolve_path(
        cfg.get("feature_config", "feature_config.yaml"), config_dir
    )
    model_cfg_path = resolve_path(
        cfg.get("model_config", "model_config.yaml"), config_dir
    )
    feature_cfg = read_yaml(feature_cfg_path)
    model_cfg = read_yaml(model_cfg_path)

    file_paths: List[str] = []
    if streaming:
        # Only peek at a small first chunk of the first file to learn the columns.
        file_paths, file_type = resolve_file_paths(str(data_path))
        if not file_paths:
            raise ValueError(f"[NextRec CLI Error] No data files found at: {data_path}")
        first_file = file_paths[0]
        first_chunk_size = max(1, min(dataloader_chunk_size, 1000))
        chunk_iter = iter_file_chunks(first_file, file_type, first_chunk_size)
        try:
            first_chunk = next(chunk_iter)
        except StopIteration as exc:
            raise ValueError(f"Data file is empty: {first_file}") from exc
        df_columns = list(first_chunk.columns)
    else:
        df = read_table(data_path, data_cfg.get("format"))
        df_columns = list(df.columns)

    dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)

    # Optional id column, forwarded to fit() as user_id_column (e.g. for GAUC).
    id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
    id_columns = [id_column] if id_column else []

    # Keep the first occurrence of each column, preserving order.
    unique_used_columns = list(
        dict.fromkeys(dense_names + sparse_names + sequence_names + target + id_columns)
    )

    processor = DataProcessor()
    register_processor_features(
        processor, feature_cfg, dense_names, sparse_names, sequence_names
    )

    processed = None
    if streaming:
        # NOTE(review): the processor is fitted on every file under data_path,
        # including any files later held out for validation via valid_ratio.
        processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
    else:
        # NOTE(review): fitted on the full frame before the random train/valid
        # split below, so validation rows contribute to fitted statistics.
        df = df[unique_used_columns]
        processor.fit(df)
        processed = processor.transform(df, return_dict=True)

    processor.save(processor_path)
    dense_features, sparse_features, sequence_features = build_feature_objects(
        processor,
        feature_cfg,
        dense_names,
        sparse_names,
        sequence_names,
    )

    val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")

    streaming_train_files: List[str] = file_paths
    streaming_valid_files: List[str] | None = None
    train_data: Dict[str, Any] | None = None
    valid_data: Dict[str, Any] | None = None

    if streaming:
        streaming_valid_ratio = data_cfg.get("valid_ratio")
        if not val_data_path and streaming_valid_ratio is not None:
            streaming_train_files, streaming_valid_files = _split_streaming_files(
                file_paths, streaming_valid_ratio
            )
            ratio = float(streaming_valid_ratio)
            logger.info(
                f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
            )
        elif not val_data_path:
            logger.info(
                "Streaming training mode: No validation dataset path specified and valid_ratio not configured, skipping validation dataset creation"
            )
    elif val_data_path:
        # Use the separately specified validation dataset.
        logger.info(
            f"Validation using specified validation dataset path: {val_data_path}"
        )
        val_df = read_table(resolve_path(val_data_path, config_dir), data_cfg.get("format"))
        val_df = val_df[unique_used_columns]
        if not isinstance(processed, dict):
            raise TypeError("Processed data must be a dictionary")
        train_data = processed
        valid_data_result = processor.transform(val_df, return_dict=True)
        if not isinstance(valid_data_result, dict):
            raise TypeError("Validation data must be a dictionary")
        valid_data = valid_data_result
        # Row count taken from the first column; 0 if the dict is empty.
        train_size = len(next(iter(train_data.values()), []))
        valid_size = len(next(iter(valid_data.values()), []))
        logger.info(
            f"Sample count - Training set: {train_size}, Validation set: {valid_size}"
        )
    else:
        # Random in-memory split using valid_ratio.
        valid_ratio = data_cfg.get("valid_ratio", 0.2)
        logger.info(f"Splitting data using valid_ratio: {valid_ratio}")
        if not isinstance(processed, dict):
            raise TypeError("Processed data must be a dictionary for splitting")
        train_data, valid_data = split_dict_random(
            processed,
            test_size=valid_ratio,
            random_state=data_cfg.get("random_state", 2024),
        )

    dataloader = RecDataLoader(
        dense_features=dense_features,
        sparse_features=sparse_features,
        sequence_features=sequence_features,
        target=target,
        id_columns=id_columns,
        # Streaming loaders transform raw chunks on the fly; in-memory data
        # has already been transformed above.
        processor=processor if streaming else None,
    )
    if streaming:
        train_loader = dataloader.create_dataloader(
            data=streaming_train_files,
            batch_size=train_batch_size,
            shuffle=dataloader_cfg.get("train_shuffle", True),
            load_full=False,
            chunk_size=dataloader_chunk_size,
            num_workers=num_workers,
        )
        valid_source: str | List[str] | None = None
        if val_data_path:
            valid_source = str(resolve_path(val_data_path, config_dir))
        elif streaming_valid_files:
            valid_source = streaming_valid_files
        valid_loader = None
        if valid_source is not None:
            valid_loader = dataloader.create_dataloader(
                data=valid_source,
                batch_size=valid_batch_size,
                shuffle=valid_shuffle,
                load_full=False,
                chunk_size=dataloader_chunk_size,
                num_workers=num_workers,
            )
    else:
        train_loader = dataloader.create_dataloader(
            data=train_data,
            batch_size=train_batch_size,
            shuffle=dataloader_cfg.get("train_shuffle", True),
            num_workers=num_workers,
        )
        valid_loader = dataloader.create_dataloader(
            data=valid_data,
            batch_size=valid_batch_size,
            shuffle=valid_shuffle,
            num_workers=num_workers,
        )

    model_cfg.setdefault("session_id", session_id)
    train_cfg = cfg.get("train", {}) or {}
    device = train_cfg.get("device", model_cfg.get("device", "cpu"))
    model = build_model_instance(
        model_cfg,
        model_cfg_path,
        dense_features,
        sparse_features,
        sequence_features,
        target,
        device,
    )

    model.compile(
        optimizer=train_cfg.get("optimizer", "adam"),
        optimizer_params=train_cfg.get("optimizer_params", {}),
        loss=train_cfg.get("loss", "focal"),
        loss_params=train_cfg.get("loss_params", {}),
    )

    model.fit(
        train_data=train_loader,
        valid_data=valid_loader,
        metrics=train_cfg.get("metrics", ["auc", "recall", "precision"]),
        epochs=train_cfg.get("epochs", 1),
        batch_size=train_cfg.get("batch_size", train_batch_size),
        shuffle=train_cfg.get("shuffle", True),
        num_workers=num_workers,
        user_id_column=id_column,
        tensorboard=False,
    )
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def predict_model(predict_config_path: str) -> None:
    """
    Run prediction using a trained model and configuration file.

    Configuration keys read:
    - session: ``id`` and ``artifact_root`` locating the training artifacts at
      ``<artifact_root>/<id>``; the defaults match ``train_model``.
    - checkpoint_path: optional override of that location, used as given
      (relative to the current working directory). Either a directory holding
      ``*.model`` checkpoints or a checkpoint file; ``.model`` is appended when
      the path has no suffix.
    - model_config: model configuration YAML, resolved relative to the config file.
    - targets: optional override of the target columns.
    - predict: ``data_path`` and ``output_path`` (required), plus optional
      ``device``, ``batch_size``, ``id_column``, ``load_full``, ``chunk_size``,
      ``save_format``, ``num_workers`` and ``preview_rows``.

    ``processor.pkl`` (directly, or under ``processor/``) and
    ``features_config.pkl`` are read from the checkpoint's directory.

    Raises:
        FileNotFoundError: if no model checkpoint, processor or
            features_config.pkl can be found.
    """
    config_file = Path(predict_config_path)
    config_dir = config_file.resolve().parent
    cfg = read_yaml(config_file)

    session_cfg = cfg.get("session", {}) or {}
    # Same default as train_model so a session trained with the default id is
    # found again at prediction time.
    session_id = session_cfg.get("id", "nextrec_cli_session")
    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
    session_dir = Path(cfg.get("checkpoint_path") or (artifact_root / session_id))
    setup_logger(session_id=session_id)

    predict_cfg = cfg.get("predict", {}) or {}
    device = predict_cfg.get("device", "cpu")

    model_cfg_path = resolve_path(
        cfg.get("model_config", "model_config.yaml"), config_dir
    )
    model_cfg = read_yaml(model_cfg_path)
    model_cfg.setdefault("session_id", session_id)
    # setdefault would keep an explicit ``params: null``; normalise it to {}.
    if model_cfg.get("params") is None:
        model_cfg["params"] = {}
    model_params = model_cfg["params"]

    # Resolve the checkpoint file and the directory that holds its artifacts.
    if session_dir.is_dir():
        # NOTE(review): picks the lexicographically last *.model file; this is
        # the newest checkpoint only if file names sort chronologically — confirm.
        candidates = sorted(session_dir.glob("*.model"))
        if not candidates:
            raise FileNotFoundError(
                f"[NextRec CLI Error]: Unable to find model checkpoint: {session_dir}"
            )
        model_file = candidates[-1]
        artifact_dir = session_dir
    else:
        model_file = (
            session_dir.with_suffix(".model") if session_dir.suffix == "" else session_dir
        )
        artifact_dir = model_file.parent

    # Look for processor.pkl directly in the artifact directory, then in processor/.
    processor_path = artifact_dir / "processor.pkl"
    if not processor_path.exists():
        processor_path = artifact_dir / "processor" / "processor.pkl"
    if not processor_path.exists():
        raise FileNotFoundError(
            f"[NextRec CLI Error]: Unable to find processor.pkl in: {artifact_dir}"
        )
    processor = DataProcessor.load(processor_path)

    features_config_path = artifact_dir / "features_config.pkl"
    if not features_config_path.exists():
        raise FileNotFoundError(
            f"[NextRec CLI Error]: Unable to find features_config.pkl: {features_config_path}"
        )
    # SECURITY: pickle.load can execute arbitrary code; only point this CLI at
    # checkpoints produced by a trusted source.
    with open(features_config_path, "rb") as f:
        features_config = pickle.load(f)

    all_features = features_config.get("all_features", [])
    target_cols = features_config.get("target", [])
    saved_id_columns = features_config.get("id_columns", [])

    dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
    sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
    sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]

    target_override = (
        cfg.get("targets")
        or model_cfg.get("targets")
        or model_params.get("targets")
        or model_params.get("target")
    )
    if target_override:
        target_cols = normalize_to_list(target_override)

    model = build_model_instance(
        model_cfg=model_cfg,
        model_cfg_path=model_cfg_path,
        dense_features=dense_features,
        sparse_features=sparse_features,
        sequence_features=sequence_features,
        target=target_cols,
        device=device,
    )
    model.id_columns = saved_id_columns
    model.load_model(model_file, map_location=device, verbose=True)

    # predict.id_column overrides the id columns saved at training time;
    # otherwise keep whatever the model carries after loading (previously the
    # saved id columns were unconditionally reset to []).
    override_id_column = predict_cfg.get("id_column")
    if override_id_column:
        model.id_columns = [override_id_column]
    id_columns = list(model.id_columns or [])

    rec_dataloader = RecDataLoader(
        dense_features=model.dense_features,
        sparse_features=model.sparse_features,
        sequence_features=model.sequence_features,
        target=None,
        id_columns=id_columns,
        processor=processor,
    )

    data_path = resolve_path(predict_cfg["data_path"], config_dir)
    batch_size = predict_cfg.get("batch_size", 512)

    pred_loader = rec_dataloader.create_dataloader(
        data=str(data_path),
        batch_size=batch_size,
        shuffle=False,
        load_full=predict_cfg.get("load_full", False),
        chunk_size=predict_cfg.get("chunk_size", 20000),
    )

    output_path = resolve_path(predict_cfg["output_path"], config_dir)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_format = predict_cfg.get("save_format", "csv")

    start = time.time()
    model.predict(
        data=pred_loader,
        batch_size=batch_size,
        include_ids=bool(id_columns),
        return_dataframe=False,
        save_path=output_path,
        save_format=save_format,
        num_workers=predict_cfg.get("num_workers", 0),
    )
    duration = time.time() - start
    logger.info(f"Prediction completed, results saved to: {output_path}")
    logger.info(f"Total time: {duration:.2f} seconds")

    preview_rows = predict_cfg.get("preview_rows") or 0
    if preview_rows > 0:
        # Best-effort preview. NOTE(review): assumes model.predict wrote to
        # output_path exactly; failures are logged, not raised.
        try:
            if str(save_format).lower() == "parquet":
                preview = pd.read_parquet(output_path).head(preview_rows)
            else:
                preview = pd.read_csv(output_path, nrows=preview_rows, low_memory=False)
            logger.info(f"Output preview:\n{preview}")
        except Exception as exc:  # pragma: no cover
            logger.warning(f"Failed to read output preview: {exc}")
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def main() -> None:
    """Entry point of the ``nextrec`` command.

    ``--mode=train`` needs ``--train_config`` and runs ``train_model``;
    ``--mode=predict`` needs ``--predict_config`` and runs ``predict_model``.
    A missing config path is reported through ``ArgumentParser.error``.
    """
    cli = argparse.ArgumentParser(
        description="NextRec: Training and Prediction Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Train a model
nextrec --mode=train --train_config=configs/train_config.yaml

# Run prediction
nextrec --mode=predict --predict_config=configs/predict_config.yaml
""",
    )
    cli.add_argument(
        "--mode",
        choices=["train", "predict"],
        required=True,
        help="Running mode: train or predict",
    )
    option_help = (
        ("--train_config", "Training configuration file path"),
        ("--predict_config", "Prediction configuration file path"),
    )
    for flag, text in option_help:
        cli.add_argument(flag, help=text)
    parsed = cli.parse_args()

    # mode -> (argparse destination holding the config path, error if missing)
    requirements = {
        "train": ("train_config", "[NextRec CLI Error] train mode requires --train_config"),
        "predict": (
            "predict_config",
            "[NextRec CLI Error] predict mode requires --predict_config",
        ),
    }
    dest, missing_message = requirements[parsed.mode]
    chosen_config = getattr(parsed, dest)
    if not chosen_config:
        cli.error(missing_message)

    runner = train_model if parsed.mode == "train" else predict_model
    runner(chosen_config)


if __name__ == "__main__":
    main()
|
nextrec/data/__init__.py
CHANGED
|
@@ -27,35 +27,29 @@ from nextrec.data import data_utils
|
|
|
27
27
|
|
|
28
28
|
__all__ = [
    # Batch utilities
    "collate_fn", "batch_to_dict", "stack_section",
    # Data processing
    "get_column_data", "split_dict_random", "build_eval_candidates", "get_user_ids",
    # File utilities
    "resolve_file_paths", "iter_file_chunks", "read_table", "load_dataframes",
    "default_output_dir",
    # DataLoader
    "TensorDictDataset", "FileDataset", "RecDataLoader", "build_tensors_from_data",
    # Preprocessor
    "DataProcessor",
    # Features
    "FeatureSet",
    # Legacy module
    "data_utils",
]
|
nextrec/data/batch_utils.py
CHANGED
|
@@ -9,16 +9,22 @@ import torch
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from typing import Any, Mapping
|
|
11
11
|
|
|
12
|
+
|
|
12
13
|
def stack_section(batch: list[dict], section: str):
    """Stack one named section of every sample in ``batch``.

    Samples whose ``section`` entry is missing or None are ignored. The keys of
    the first present section decide which names are produced; for each name,
    the tensors of every present section containing it are stacked along a new
    dimension 0.

    Returns:
        ``{name: stacked_tensor}``, or None when no sample has the section.
    """
    present = [
        part for part in (sample.get(section) for sample in batch) if part is not None
    ]
    if not present:
        return None
    return {
        key: torch.stack([part[key] for part in present if key in part], dim=0)
        for key in present[0]
    }
|
|
21
26
|
|
|
27
|
+
|
|
22
28
|
def collate_fn(batch):
|
|
23
29
|
"""
|
|
24
30
|
Collate a list of sample dicts into the unified batch format:
|
|
@@ -28,7 +34,7 @@ def collate_fn(batch):
|
|
|
28
34
|
"ids": {id_name: Tensor(B, ...)} or None,
|
|
29
35
|
}
|
|
30
36
|
Args: batch: List of samples from DataLoader
|
|
31
|
-
|
|
37
|
+
|
|
32
38
|
Returns: dict: Batched data in unified format
|
|
33
39
|
"""
|
|
34
40
|
if not batch:
|
|
@@ -72,7 +78,9 @@ def collate_fn(batch):
|
|
|
72
78
|
|
|
73
79
|
def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
|
|
74
80
|
if not (isinstance(batch_data, Mapping) and "features" in batch_data):
|
|
75
|
-
raise TypeError(
|
|
81
|
+
raise TypeError(
|
|
82
|
+
"[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader."
|
|
83
|
+
)
|
|
76
84
|
return {
|
|
77
85
|
"features": batch_data.get("features", {}),
|
|
78
86
|
"labels": batch_data.get("labels"),
|