nextrec-0.3.6-py3-none-any.whl → nextrec-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py CHANGED
@@ -7,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
  import os
  import tempfile
  from dataclasses import dataclass
- from datetime import datetime, timezone
+ from datetime import datetime
  from pathlib import Path
 
  __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
      "create_session",
  ]
 
+
  @dataclass(frozen=True)
  class Session:
      """Encapsulate standard folders for a NextRec experiment."""
@@ -35,7 +36,7 @@ class Session:
      @property
      def predictions_dir(self) -> Path:
          return self._ensure_dir(self.root / "predictions")
-
+
      @property
      def processor_dir(self) -> Path:
          return self._ensure_dir(self.root / "processor")
@@ -60,6 +61,7 @@ class Session:
          path.mkdir(parents=True, exist_ok=True)
          return path
 
+
  def create_session(experiment_id: str | Path | None = None) -> Session:
 
      if experiment_id is not None and str(experiment_id).strip():
@@ -86,6 +88,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
 
      return Session(experiment_id=exp_id, root=root, log_basename=log_basename)
 
+
  def resolve_save_path(
      path: str | os.PathLike | Path | None,
      default_dir: str | Path,
@@ -129,7 +132,11 @@ def resolve_save_path(
              base_dir = candidate
              file_stem = default_name
          else:
-             base_dir = candidate.parent if candidate.parent not in (Path("."), Path("")) else base_dir
+             base_dir = (
+                 candidate.parent
+                 if candidate.parent not in (Path("."), Path(""))
+                 else base_dir
+             )
              file_stem = candidate.name or default_name
      else:
          file_stem = default_name
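
For context, a brief usage sketch of the Session helpers touched above. Only create_session, the root field, and the predictions_dir/processor_dir properties appear in the hunks; the experiment id and the print calls are illustrative, not part of the diff.

    # illustrative usage; not part of the nextrec diff
    from nextrec.basic.session import create_session

    session = create_session("my_experiment")  # a generated id is used when None/blank
    print(session.root)             # experiment root directory
    print(session.predictions_dir)  # created lazily via _ensure_dir
    print(session.processor_dir)
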
nextrec/cli.py ADDED
@@ -0,0 +1,498 @@
"""
Command-line interface for NextRec training and prediction.

NextRec supports a flexible training and prediction pipeline driven by configuration files.
After preparing the configuration YAML files for training and prediction, users can run the
following script to execute the desired operations.

Examples:
    # Train a model
    nextrec --mode=train --train_config=tutorials/iflytek/scripts/masknet/train_config.yaml

    # Run prediction
    nextrec --mode=predict --predict_config=tutorials/iflytek/scripts/masknet/predict_config.yaml

Date: created on 06/12/2025
Author: Yang Zhou, zyaztec@gmail.com
"""

import argparse
import logging
import pickle
import time
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd

from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
from nextrec.data.data_utils import split_dict_random
from nextrec.data.dataloader import RecDataLoader
from nextrec.data.preprocessor import DataProcessor
from nextrec.utils.config import (
    build_feature_objects,
    build_model_instance,
    extract_feature_groups,
    register_processor_features,
    resolve_path,
    select_features,
)
from nextrec.utils.feature import normalize_to_list
from nextrec.utils.file import (
    iter_file_chunks,
    read_table,
    read_yaml,
    resolve_file_paths,
)
from nextrec.basic.loggers import setup_logger

logger = logging.getLogger(__name__)


def train_model(train_config_path: str) -> None:
    """
    Train a NextRec model using the provided configuration file.

    The configuration file must specify the following sections:
    - session: Session settings, including id and artifact root
    - data: Data settings, including path, format, target, validation split
    - dataloader: DataLoader settings, including batch sizes and shuffling
    - model_config: Path to the model configuration YAML file
    - feature_config: Path to the feature configuration YAML file
    - train: Training settings, including optimizer, loss, metrics, epochs, etc.
    """
    config_file = Path(train_config_path)
    config_dir = config_file.resolve().parent
    cfg = read_yaml(config_file)

    # read session configuration
    session_cfg = cfg.get("session", {}) or {}
    session_id = session_cfg.get("id", "nextrec_cli_session")
    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
    session_dir = artifact_root / session_id
    setup_logger(session_id=session_id)

    processor_path = session_dir / "processor.pkl"
    processor_path = Path(processor_path)
    processor_path.parent.mkdir(parents=True, exist_ok=True)

    data_cfg = cfg.get("data", {}) or {}
    dataloader_cfg = cfg.get("dataloader", {}) or {}
    streaming = bool(data_cfg.get("streaming", False))
    dataloader_chunk_size = dataloader_cfg.get("chunk_size", 20000)

    # train data
    data_path = resolve_path(data_cfg["path"], config_dir)
    target = normalize_to_list(data_cfg["target"])
    file_paths: List[str] = []
    file_type: str | None = None
    streaming_train_files: List[str] | None = None
    streaming_valid_files: List[str] | None = None

    feature_cfg_path = resolve_path(
        cfg.get("feature_config", "feature_config.yaml"), config_dir
    )
    model_cfg_path = resolve_path(
        cfg.get("model_config", "model_config.yaml"), config_dir
    )

    feature_cfg = read_yaml(feature_cfg_path)
    model_cfg = read_yaml(model_cfg_path)

    if streaming:
        file_paths, file_type = resolve_file_paths(str(data_path))
        first_file = file_paths[0]
        first_chunk_size = max(1, min(dataloader_chunk_size, 1000))
        chunk_iter = iter_file_chunks(first_file, file_type, first_chunk_size)
        try:
            first_chunk = next(chunk_iter)
        except StopIteration as exc:
            raise ValueError(f"Data file is empty: {first_file}") from exc
        df_columns = list(first_chunk.columns)
    else:
        df = read_table(data_path, data_cfg.get("format"))
        df_columns = list(df.columns)

    # some models have independent feature groups, so extract them here
    feature_groups, grouped_columns = extract_feature_groups(feature_cfg, df_columns)
    if feature_groups:
        model_cfg.setdefault("params", {})
        model_cfg["params"].setdefault("feature_groups", feature_groups)

    dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
    used_columns = (
        dense_names + sparse_names + sequence_names + grouped_columns + target
    )

    # keep order but drop duplicates
    seen = set()
    unique_used_columns = []
    for col in used_columns:
        if col not in seen:
            unique_used_columns.append(col)
            seen.add(col)

    processor = DataProcessor()
    register_processor_features(
        processor, feature_cfg, dense_names, sparse_names, sequence_names
    )

    if streaming:
        processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
        processed = None
        df = None  # type: ignore[assignment]
    else:
        df = df[unique_used_columns]
        processor.fit(df)
        processed = processor.transform(df, return_dict=True)

    processor.save(processor_path)
    dense_features, sparse_features, sequence_features = build_feature_objects(
        processor,
        feature_cfg,
        dense_names,
        sparse_names,
        sequence_names,
    )

    # Check if a validation dataset path is specified
    val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
    if streaming:
        if not file_paths:
            file_paths, file_type = resolve_file_paths(str(data_path))
        streaming_train_files = file_paths
        streaming_valid_ratio = data_cfg.get("valid_ratio")
        if val_data_path:
            streaming_valid_files = None
        elif streaming_valid_ratio is not None:
            ratio = float(streaming_valid_ratio)
            if not (0 < ratio < 1):
                raise ValueError(
                    f"[NextRec CLI Error] valid_ratio must be between 0 and 1, got {streaming_valid_ratio}"
                )
            total_files = len(file_paths)
            if total_files < 2:
                raise ValueError(
                    "[NextRec CLI Error] Provide val_path or increase the number of data files; at least 2 files are required for a streaming validation split."
                )
            val_count = max(1, int(round(total_files * ratio)))
            if val_count >= total_files:
                val_count = total_files - 1
            streaming_valid_files = file_paths[-val_count:]
            streaming_train_files = file_paths[:-val_count]
            logger.info(
                "Splitting files with valid_ratio=%.3f: %d training files, %d validation files",
                ratio,
                len(streaming_train_files),
                len(streaming_valid_files),
            )
    train_data: Dict[str, Any]
    valid_data: Dict[str, Any] | None

    if val_data_path and not streaming:
        # Use the specified validation dataset path
        logger.info("Using the specified validation set path: %s", val_data_path)
        val_data_resolved = resolve_path(val_data_path, config_dir)
        val_df = read_table(val_data_resolved, data_cfg.get("format"))
        val_df = val_df[unique_used_columns]
        if not isinstance(processed, dict):
            raise TypeError("Processed data must be a dictionary")
        train_data = processed
        valid_data_result = processor.transform(val_df, return_dict=True)
        if not isinstance(valid_data_result, dict):
            raise TypeError("Validation data must be a dictionary")
        valid_data = valid_data_result
        train_size = len(list(train_data.values())[0])
        valid_size = len(list(valid_data.values())[0])
        logger.info("Training samples: %s, validation samples: %s", train_size, valid_size)
    elif streaming:
        train_data = None  # type: ignore[assignment]
        valid_data = None
        if not val_data_path and not streaming_valid_files:
            logger.info(
                "Streaming training mode: no validation path specified and no valid_ratio configured, skipping validation set creation"
            )
    else:
        # Split data using valid_ratio
        logger.info("Splitting data with valid_ratio: %s", data_cfg.get("valid_ratio", 0.2))
        if not isinstance(processed, dict):
            raise TypeError("Processed data must be a dictionary for splitting")
        train_data, valid_data = split_dict_random(
            processed,
            test_size=data_cfg.get("valid_ratio", 0.2),
            random_state=data_cfg.get("random_state", 2024),
        )

    dataloader = RecDataLoader(
        dense_features=dense_features,
        sparse_features=sparse_features,
        sequence_features=sequence_features,
        target=target,
        processor=processor if streaming else None,
    )
    if streaming:
        train_stream_source = streaming_train_files or file_paths
        train_loader = dataloader.create_dataloader(
            data=train_stream_source,
            batch_size=dataloader_cfg.get("train_batch_size", 512),
            shuffle=dataloader_cfg.get("train_shuffle", True),
            load_full=False,
            chunk_size=dataloader_chunk_size,
        )
        valid_loader = None
        if val_data_path:
            val_data_resolved = resolve_path(val_data_path, config_dir)
            valid_loader = dataloader.create_dataloader(
                data=str(val_data_resolved),
                batch_size=dataloader_cfg.get("valid_batch_size", 512),
                shuffle=dataloader_cfg.get("valid_shuffle", False),
                load_full=False,
                chunk_size=dataloader_chunk_size,
            )
        elif streaming_valid_files:
            valid_loader = dataloader.create_dataloader(
                data=streaming_valid_files,
                batch_size=dataloader_cfg.get("valid_batch_size", 512),
                shuffle=dataloader_cfg.get("valid_shuffle", False),
                load_full=False,
                chunk_size=dataloader_chunk_size,
            )
    else:
        train_loader = dataloader.create_dataloader(
            data=train_data,
            batch_size=dataloader_cfg.get("train_batch_size", 512),
            shuffle=dataloader_cfg.get("train_shuffle", True),
        )
        valid_loader = dataloader.create_dataloader(
            data=valid_data,
            batch_size=dataloader_cfg.get("valid_batch_size", 512),
            shuffle=dataloader_cfg.get("valid_shuffle", False),
        )

    model_cfg.setdefault("session_id", session_id)
    train_cfg = cfg.get("train", {}) or {}
    device = train_cfg.get("device", model_cfg.get("device", "cpu"))
    model = build_model_instance(
        model_cfg,
        model_cfg_path,
        dense_features,
        sparse_features,
        sequence_features,
        target,
        device,
    )

    model.compile(
        optimizer=train_cfg.get("optimizer", "adam"),
        optimizer_params=train_cfg.get("optimizer_params", {}),
        loss=train_cfg.get("loss", "focal"),
        loss_params=train_cfg.get("loss_params", {}),
    )

    model.fit(
        train_data=train_loader,
        valid_data=valid_loader,
        metrics=train_cfg.get("metrics", ["auc", "recall", "precision"]),
        epochs=train_cfg.get("epochs", 1),
        batch_size=train_cfg.get(
            "batch_size", dataloader_cfg.get("train_batch_size", 512)
        ),
        shuffle=train_cfg.get("shuffle", True),
    )


def predict_model(predict_config_path: str) -> None:
    """
    Run prediction using a trained model and configuration file.
    """
    config_file = Path(predict_config_path)
    config_dir = config_file.resolve().parent
    cfg = read_yaml(config_file)

    session_cfg = cfg.get("session", {}) or {}
    session_id = session_cfg.get("id", "masknet_tutorial")
    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
    session_dir = Path(cfg.get("checkpoint_path") or (artifact_root / session_id))
    setup_logger(session_id=session_id)

    processor_path = Path(session_dir / "processor.pkl")
    if not processor_path.exists():
        processor_path = session_dir / "processor" / "processor.pkl"

    predict_cfg = cfg.get("predict", {}) or {}
    model_cfg_path = resolve_path(
        cfg.get("model_config", "model_config.yaml"), config_dir
    )
    feature_cfg_path = resolve_path(
        cfg.get("feature_config", "feature_config.yaml"), config_dir
    )

    model_cfg = read_yaml(model_cfg_path)
    feature_cfg = read_yaml(feature_cfg_path)
    model_cfg.setdefault("session_id", session_id)
    feature_groups_raw = feature_cfg.get("feature_groups") or {}
    model_cfg.setdefault("params", {})

    # attach feature_groups in the predict phase to avoid missing bindings
    model_cfg["params"]["feature_groups"] = feature_groups_raw

    processor = DataProcessor.load(processor_path)

    # Load the checkpoint and ensure required parameters are passed
    checkpoint_base = Path(session_dir)
    if checkpoint_base.is_dir():
        candidates = sorted(checkpoint_base.glob("*.model"))
        if not candidates:
            raise FileNotFoundError(
                f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
            )
        model_file = candidates[-1]
        config_dir_for_features = checkpoint_base
    else:
        model_file = (
            checkpoint_base.with_suffix(".model")
            if checkpoint_base.suffix == ""
            else checkpoint_base
        )
        config_dir_for_features = model_file.parent

    features_config_path = config_dir_for_features / "features_config.pkl"
    if not features_config_path.exists():
        raise FileNotFoundError(
            f"[NextRec CLI Error]: Unable to find features_config.pkl: {features_config_path}"
        )
    with open(features_config_path, "rb") as f:
        features_config = pickle.load(f)

    all_features = features_config.get("all_features", [])
    target_cols = features_config.get("target", [])
    id_columns = features_config.get("id_columns", [])

    dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
    sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
    sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]

    target_override = (
        cfg.get("targets")
        or model_cfg.get("targets")
        or model_cfg.get("params", {}).get("targets")
        or model_cfg.get("params", {}).get("target")
    )
    if target_override:
        target_cols = normalize_to_list(target_override)

    # Recompute feature_groups with available feature names to drive bindings
    feature_group_names = [f.name for f in all_features if hasattr(f, "name")]
    parsed_feature_groups, _ = extract_feature_groups(feature_cfg, feature_group_names)
    if parsed_feature_groups:
        model_cfg.setdefault("params", {})
        model_cfg["params"]["feature_groups"] = parsed_feature_groups

    model = build_model_instance(
        model_cfg=model_cfg,
        model_cfg_path=model_cfg_path,
        dense_features=dense_features,
        sparse_features=sparse_features,
        sequence_features=sequence_features,
        target=target_cols,
        device=predict_cfg.get("device", "cpu"),
    )
    model.id_columns = id_columns
    model.load_model(
        model_file, map_location=predict_cfg.get("device", "cpu"), verbose=True
    )

    id_columns = []
    if predict_cfg.get("id_column"):
        id_columns = [predict_cfg["id_column"]]
    model.id_columns = id_columns

    rec_dataloader = RecDataLoader(
        dense_features=model.dense_features,
        sparse_features=model.sparse_features,
        sequence_features=model.sequence_features,
        target=None,
        id_columns=id_columns or model.id_columns,
        processor=processor,
    )

    data_path = resolve_path(predict_cfg["data_path"], config_dir)
    batch_size = predict_cfg.get("batch_size", 512)

    pred_loader = rec_dataloader.create_dataloader(
        data=str(data_path),
        batch_size=batch_size,
        shuffle=False,
        load_full=predict_cfg.get("load_full", False),
        chunk_size=predict_cfg.get("chunk_size", 20000),
    )

    output_path = resolve_path(predict_cfg["output_path"], config_dir)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    start = time.time()
    model.predict(
        data=pred_loader,
        batch_size=batch_size,
        include_ids=bool(id_columns),
        return_dataframe=False,
        save_path=output_path,
        save_format=predict_cfg.get("save_format", "csv"),
    )
    duration = time.time() - start
    logger.info(f"Prediction completed, results saved to: {output_path}")
    logger.info(f"Total time: {duration:.2f} seconds")

    preview_rows = predict_cfg.get("preview_rows", 0)
    if preview_rows > 0:
        try:
            preview = pd.read_csv(output_path, nrows=preview_rows)
            logger.info(f"Output preview:\n{preview}")
        except Exception as exc:  # pragma: no cover
            logger.warning(f"Failed to read output preview: {exc}")


def main() -> None:
    """Parse CLI arguments and dispatch to train or predict mode."""
    parser = argparse.ArgumentParser(
        description="NextRec: Training and Prediction Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Train a model
  nextrec --mode=train --train_config=configs/train_config.yaml

  # Run prediction
  nextrec --mode=predict --predict_config=configs/predict_config.yaml
""",
    )
    parser.add_argument(
        "--mode",
        choices=["train", "predict"],
        required=True,
        help="Run mode: train or predict",
    )
    parser.add_argument("--train_config", help="Path to the training configuration file")
    parser.add_argument("--predict_config", help="Path to the prediction configuration file")
    parser.add_argument(
        "--config",
        help="Generic configuration file path (deprecated; use --train_config or --predict_config instead)",
    )
    args = parser.parse_args()

    if args.mode == "train":
        config_path = args.train_config or args.config
        if not config_path:
            parser.error("--train_config is required in train mode")
        train_model(config_path)
    else:
        config_path = args.predict_config or args.config
        if not config_path:
            parser.error("--predict_config is required in predict mode")
        predict_model(config_path)


if __name__ == "__main__":
    main()
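
For orientation, a sketch of the train-side configuration that train_model expects. The key names mirror the cfg.get(...) calls above; every value (paths, ids, hyperparameters) is illustrative rather than taken from the package, and PyYAML is assumed to be available (the CLI itself reads YAML via read_yaml).

    # illustrative config skeleton; key names follow train_model above, values are invented
    import yaml

    train_config = {
        "session": {"id": "my_experiment", "artifact_root": "nextrec_logs"},
        "data": {
            "path": "data/train.csv",   # hypothetical path
            "format": "csv",
            "target": "label",
            "streaming": False,
            "valid_ratio": 0.2,
            "random_state": 2024,
        },
        "dataloader": {"train_batch_size": 512, "valid_batch_size": 512, "chunk_size": 20000},
        "feature_config": "feature_config.yaml",
        "model_config": "model_config.yaml",
        "train": {
            "optimizer": "adam",
            "loss": "focal",
            "metrics": ["auc", "recall", "precision"],
            "epochs": 1,
            "device": "cpu",
        },
    }

    with open("train_config.yaml", "w") as f:
        yaml.safe_dump(train_config, f, sort_keys=False)
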
nextrec/data/__init__.py CHANGED
@@ -27,35 +27,29 @@ from nextrec.data import data_utils
 
  __all__ = [
      # Batch utilities
-     'collate_fn',
-     'batch_to_dict',
-     'stack_section',
-
+     "collate_fn",
+     "batch_to_dict",
+     "stack_section",
      # Data processing
-     'get_column_data',
-     'split_dict_random',
-     'build_eval_candidates',
-     'get_user_ids',
-
+     "get_column_data",
+     "split_dict_random",
+     "build_eval_candidates",
+     "get_user_ids",
      # File utilities
-     'resolve_file_paths',
-     'iter_file_chunks',
-     'read_table',
-     'load_dataframes',
-     'default_output_dir',
-
+     "resolve_file_paths",
+     "iter_file_chunks",
+     "read_table",
+     "load_dataframes",
+     "default_output_dir",
      # DataLoader
-     'TensorDictDataset',
-     'FileDataset',
-     'RecDataLoader',
-     'build_tensors_from_data',
-
+     "TensorDictDataset",
+     "FileDataset",
+     "RecDataLoader",
+     "build_tensors_from_data",
      # Preprocessor
-     'DataProcessor',
-
+     "DataProcessor",
      # Features
-     'FeatureSet',
-
+     "FeatureSet",
      # Legacy module
-     'data_utils',
+     "data_utils",
  ]
nextrec/data/batch_utils.py CHANGED
@@ -9,16 +9,22 @@ import torch
  import numpy as np
  from typing import Any, Mapping
 
+
  def stack_section(batch: list[dict], section: str):
      entries = [item.get(section) for item in batch if item.get(section) is not None]
      if not entries:
          return None
      merged: dict = {}
      for name in entries[0]:  # type: ignore
-         tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
+         tensors = [
+             item[section][name]
+             for item in batch
+             if item.get(section) is not None and name in item[section]
+         ]
          merged[name] = torch.stack(tensors, dim=0)
      return merged
 
+
  def collate_fn(batch):
      """
      Collate a list of sample dicts into the unified batch format:
@@ -28,7 +34,7 @@ def collate_fn(batch):
          "ids": {id_name: Tensor(B, ...)} or None,
      }
      Args: batch: List of samples from DataLoader
-
+
      Returns: dict: Batched data in unified format
      """
      if not batch:
@@ -72,7 +78,9 @@ def collate_fn(batch):
 
  def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
      if not (isinstance(batch_data, Mapping) and "features" in batch_data):
-         raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
+         raise TypeError(
+             "[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader."
+         )
      return {
          "features": batch_data.get("features", {}),
          "labels": batch_data.get("labels"),