nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +250 -112
  7. nextrec/basic/loggers.py +63 -44
  8. nextrec/basic/metrics.py +270 -120
  9. nextrec/basic/model.py +1084 -402
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +492 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +51 -45
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +273 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +103 -38
  23. nextrec/models/match/dssm.py +82 -68
  24. nextrec/models/match/dssm_v2.py +72 -57
  25. nextrec/models/match/mind.py +175 -107
  26. nextrec/models/match/sdm.py +104 -87
  27. nextrec/models/match/youtube_dnn.py +73 -59
  28. nextrec/models/multi_task/esmm.py +69 -46
  29. nextrec/models/multi_task/mmoe.py +91 -53
  30. nextrec/models/multi_task/ple.py +117 -58
  31. nextrec/models/multi_task/poso.py +163 -55
  32. nextrec/models/multi_task/share_bottom.py +63 -36
  33. nextrec/models/ranking/afm.py +80 -45
  34. nextrec/models/ranking/autoint.py +74 -57
  35. nextrec/models/ranking/dcn.py +110 -48
  36. nextrec/models/ranking/dcn_v2.py +265 -45
  37. nextrec/models/ranking/deepfm.py +39 -24
  38. nextrec/models/ranking/dien.py +335 -146
  39. nextrec/models/ranking/din.py +158 -92
  40. nextrec/models/ranking/fibinet.py +134 -52
  41. nextrec/models/ranking/fm.py +68 -26
  42. nextrec/models/ranking/masknet.py +95 -33
  43. nextrec/models/ranking/pnn.py +128 -58
  44. nextrec/models/ranking/widedeep.py +40 -28
  45. nextrec/models/ranking/xdeepfm.py +67 -40
  46. nextrec/utils/__init__.py +59 -34
  47. nextrec/utils/config.py +496 -0
  48. nextrec/utils/device.py +30 -20
  49. nextrec/utils/distributed.py +36 -9
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +33 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/model.py +22 -0
  55. nextrec/utils/optimizer.py +25 -9
  56. nextrec/utils/synthetic_data.py +283 -165
  57. nextrec/utils/tensor.py +24 -13
  58. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
  59. nextrec-0.4.3.dist-info/RECORD +69 -0
  60. nextrec-0.4.3.dist-info/entry_points.txt +2 -0
  61. nextrec-0.4.1.dist-info/RECORD +0 -66
  62. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
  63. {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py CHANGED
@@ -7,7 +7,7 @@ Author: Yang Zhou,zyaztec@gmail.com
 import os
 import tempfile
 from dataclasses import dataclass
-from datetime import datetime, timezone
+from datetime import datetime
 from pathlib import Path
 
 __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
     "create_session",
 ]
 
+
 @dataclass(frozen=True)
 class Session:
     """Encapsulate standard folders for a NextRec experiment."""
@@ -35,7 +36,7 @@ class Session:
     @property
     def predictions_dir(self) -> Path:
         return self._ensure_dir(self.root / "predictions")
-
+
     @property
     def processor_dir(self) -> Path:
         return self._ensure_dir(self.root / "processor")
@@ -60,6 +61,7 @@ class Session:
         path.mkdir(parents=True, exist_ok=True)
         return path
 
+
 def create_session(experiment_id: str | Path | None = None) -> Session:
 
     if experiment_id is not None and str(experiment_id).strip():
@@ -86,6 +88,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
 
     return Session(experiment_id=exp_id, root=root, log_basename=log_basename)
 
+
 def resolve_save_path(
     path: str | os.PathLike | Path | None,
     default_dir: str | Path,
@@ -129,7 +132,11 @@ def resolve_save_path(
             base_dir = candidate
             file_stem = default_name
         else:
-            base_dir = candidate.parent if candidate.parent not in (Path("."), Path("")) else base_dir
+            base_dir = (
+                candidate.parent
+                if candidate.parent not in (Path("."), Path(""))
+                else base_dir
+            )
             file_stem = candidate.name or default_name
     else:
         file_stem = default_name
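
Editor's note, not part of the diff: a minimal usage sketch of the Session helpers touched above, assuming nextrec 0.4.3 is installed and that the properties behave as the hunks show (directories are created lazily). The experiment id used here is illustrative.

# Hypothetical usage sketch based on the hunks above; not shipped with the package.
from nextrec.basic.session import create_session

session = create_session("demo_experiment")  # a generated id is used when None/blank
print(session.root)             # experiment root folder
print(session.predictions_dir)  # created on first access via _ensure_dir
print(session.processor_dir)    # likewise created lazily
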
nextrec/cli.py ADDED
@@ -0,0 +1,492 @@
+"""
+Command-line interface for NextRec training and prediction.
+
+
+NextRec supports a flexible training and prediction pipeline driven by configuration files.
+After preparing the configuration YAML files for training and prediction, users can run the
+following script to execute the desired operations.
+
+Examples:
+    # Train a model
+    nextrec --mode=train --train_config=nextrec_cli_preset/train_config.yaml
+
+    # Run prediction
+    nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
+
+Date: create on 06/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
+"""
+
+import argparse
+import logging
+import pickle
+import time
+from pathlib import Path
+from typing import Any, Dict, List
+
+import pandas as pd
+
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.data.data_utils import split_dict_random
+from nextrec.data.dataloader import RecDataLoader
+from nextrec.data.preprocessor import DataProcessor
+from nextrec.utils.config import (
+    build_feature_objects,
+    build_model_instance,
+    register_processor_features,
+    resolve_path,
+    select_features,
+)
+from nextrec.utils.feature import normalize_to_list
+from nextrec.utils.file import (
+    iter_file_chunks,
+    read_table,
+    read_yaml,
+    resolve_file_paths,
+)
+from nextrec.basic.loggers import setup_logger
+
+logger = logging.getLogger(__name__)
+
+
+def train_model(train_config_path: str) -> None:
+    """
+    Train a NextRec model using the provided configuration file.
+
+    configuration file must specify the below sections:
+    - session: Session settings including id and artifact root
+    - data: Data settings including path, format, target, validation split
+    - dataloader: DataLoader settings including batch sizes and shuffling
+    - model_config: Path to the model configuration YAML file
+    - feature_config: Path to the feature configuration YAML file
+    - train: Training settings including optimizer, loss, metrics, epochs, etc.
+    """
+    config_file = Path(train_config_path)
+    config_dir = config_file.resolve().parent
+    cfg = read_yaml(config_file)
+
+    # read session configuration
+    session_cfg = cfg.get("session", {}) or {}
+    session_id = session_cfg.get("id", "nextrec_cli_session")
+    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
+    session_dir = artifact_root / session_id
+    setup_logger(session_id=session_id)
+
+    processor_path = session_dir / "processor.pkl"
+    processor_path = Path(processor_path)
+    processor_path.parent.mkdir(parents=True, exist_ok=True)
+
+    data_cfg = cfg.get("data", {}) or {}
+    dataloader_cfg = cfg.get("dataloader", {}) or {}
+    streaming = bool(data_cfg.get("streaming", False))
+    dataloader_chunk_size = dataloader_cfg.get("chunk_size", 20000)
+
+    # train data
+    data_path = resolve_path(data_cfg["path"], config_dir)
+    target = normalize_to_list(data_cfg["target"])
+    file_paths: List[str] = []
+    file_type: str | None = None
+    streaming_train_files: List[str] | None = None
+    streaming_valid_files: List[str] | None = None
+
+    feature_cfg_path = resolve_path(
+        cfg.get("feature_config", "feature_config.yaml"), config_dir
+    )
+    model_cfg_path = resolve_path(
+        cfg.get("model_config", "model_config.yaml"), config_dir
+    )
+
+    feature_cfg = read_yaml(feature_cfg_path)
+    model_cfg = read_yaml(model_cfg_path)
+
+    if streaming:
+        file_paths, file_type = resolve_file_paths(str(data_path))
+        first_file = file_paths[0]
+        first_chunk_size = max(1, min(dataloader_chunk_size, 1000))
+        chunk_iter = iter_file_chunks(first_file, file_type, first_chunk_size)
+        try:
+            first_chunk = next(chunk_iter)
+        except StopIteration as exc:
+            raise ValueError(f"Data file is empty: {first_file}") from exc
+        df_columns = list(first_chunk.columns)
+
+    else:
+        df = read_table(data_path, data_cfg.get("format"))
+        df_columns = list(df.columns)
+
+    dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
+
+    # Extract id_column from data config for GAUC metrics
+    id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
+    id_columns = [id_column] if id_column else []
+
+    used_columns = dense_names + sparse_names + sequence_names + target + id_columns
+
+    # keep order but drop duplicates
+    seen = set()
+    unique_used_columns = []
+    for col in used_columns:
+        if col not in seen:
+            unique_used_columns.append(col)
+            seen.add(col)
+
+    processor = DataProcessor()
+    register_processor_features(
+        processor, feature_cfg, dense_names, sparse_names, sequence_names
+    )
+
+    if streaming:
+        processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
+        processed = None
+        df = None  # type: ignore[assignment]
+    else:
+        df = df[unique_used_columns]
+        processor.fit(df)
+        processed = processor.transform(df, return_dict=True)
+
+    processor.save(processor_path)
+    dense_features, sparse_features, sequence_features = build_feature_objects(
+        processor,
+        feature_cfg,
+        dense_names,
+        sparse_names,
+        sequence_names,
+    )
+
+    # Check if validation dataset path is specified
+    val_data_path = data_cfg.get("val_path") or data_cfg.get("valid_path")
+    if streaming:
+        if not file_paths:
+            file_paths, file_type = resolve_file_paths(str(data_path))
+        streaming_train_files = file_paths
+        streaming_valid_ratio = data_cfg.get("valid_ratio")
+        if val_data_path:
+            streaming_valid_files = None
+        elif streaming_valid_ratio is not None:
+            ratio = float(streaming_valid_ratio)
+            if not (0 < ratio < 1):
+                raise ValueError(
+                    f"[NextRec CLI Error] Valid_ratio must be between 0 and 1, current value is {streaming_valid_ratio}"
+                )
+            total_files = len(file_paths)
+            if total_files < 2:
+                raise ValueError(
+                    "[NextRec CLI Error] Must provide val_path or increase the number of data files. At least 2 files are required for streaming validation split."
+                )
+            val_count = max(1, int(round(total_files * ratio)))
+            if val_count >= total_files:
+                val_count = total_files - 1
+            streaming_valid_files = file_paths[-val_count:]
+            streaming_train_files = file_paths[:-val_count]
+            logger.info(
+                f"Split files for streaming training and validation using valid_ratio={ratio:.3f}: training {len(streaming_train_files)} files, validation {len(streaming_valid_files)} files"
+            )
+    train_data: Dict[str, Any]
+    valid_data: Dict[str, Any] | None
+
+    if val_data_path and not streaming:
+        # Use specified validation dataset path
+        logger.info(
+            f"Validation using specified validation dataset path: {val_data_path}"
+        )
+        val_data_resolved = resolve_path(val_data_path, config_dir)
+        val_df = read_table(val_data_resolved, data_cfg.get("format"))
+        val_df = val_df[unique_used_columns]
+        if not isinstance(processed, dict):
+            raise TypeError("Processed data must be a dictionary")
+        train_data = processed
+        valid_data_result = processor.transform(val_df, return_dict=True)
+        if not isinstance(valid_data_result, dict):
+            raise TypeError("Validation data must be a dictionary")
+        valid_data = valid_data_result
+        train_size = len(list(train_data.values())[0])
+        valid_size = len(list(valid_data.values())[0])
+        logger.info(
+            f"Sample count - Training set: {train_size}, Validation set: {valid_size}"
+        )
+    elif streaming:
+        train_data = None  # type: ignore[assignment]
+        valid_data = None
+        if not val_data_path and not streaming_valid_files:
+            logger.info(
+                "Streaming training mode: No validation dataset path specified and valid_ratio not configured, skipping validation dataset creation"
+            )
+    else:
+        # Split data using valid_ratio
+        logger.info(
+            f"Splitting data using valid_ratio: {data_cfg.get('valid_ratio', 0.2)}"
+        )
+        if not isinstance(processed, dict):
+            raise TypeError("Processed data must be a dictionary for splitting")
+        train_data, valid_data = split_dict_random(
+            processed,
+            test_size=data_cfg.get("valid_ratio", 0.2),
+            random_state=data_cfg.get("random_state", 2024),
+        )
+
+    dataloader = RecDataLoader(
+        dense_features=dense_features,
+        sparse_features=sparse_features,
+        sequence_features=sequence_features,
+        target=target,
+        id_columns=id_columns,
+        processor=processor if streaming else None,
+    )
+    if streaming:
+        train_stream_source = streaming_train_files or file_paths
+        train_loader = dataloader.create_dataloader(
+            data=train_stream_source,
+            batch_size=dataloader_cfg.get("train_batch_size", 512),
+            shuffle=dataloader_cfg.get("train_shuffle", True),
+            load_full=False,
+            chunk_size=dataloader_chunk_size,
+            num_workers=dataloader_cfg.get("num_workers", 0),
+        )
+        valid_loader = None
+        if val_data_path:
+            val_data_resolved = resolve_path(val_data_path, config_dir)
+            valid_loader = dataloader.create_dataloader(
+                data=str(val_data_resolved),
+                batch_size=dataloader_cfg.get("valid_batch_size", 512),
+                shuffle=dataloader_cfg.get("valid_shuffle", False),
+                load_full=False,
+                chunk_size=dataloader_chunk_size,
+                num_workers=dataloader_cfg.get("num_workers", 0),
+            )
+        elif streaming_valid_files:
+            valid_loader = dataloader.create_dataloader(
+                data=streaming_valid_files,
+                batch_size=dataloader_cfg.get("valid_batch_size", 512),
+                shuffle=dataloader_cfg.get("valid_shuffle", False),
+                load_full=False,
+                chunk_size=dataloader_chunk_size,
+                num_workers=dataloader_cfg.get("num_workers", 0),
+            )
+    else:
+        train_loader = dataloader.create_dataloader(
+            data=train_data,
+            batch_size=dataloader_cfg.get("train_batch_size", 512),
+            shuffle=dataloader_cfg.get("train_shuffle", True),
+            num_workers=dataloader_cfg.get("num_workers", 0),
+        )
+        valid_loader = dataloader.create_dataloader(
+            data=valid_data,
+            batch_size=dataloader_cfg.get("valid_batch_size", 512),
+            shuffle=dataloader_cfg.get("valid_shuffle", False),
+            num_workers=dataloader_cfg.get("num_workers", 0),
+        )
+
+    model_cfg.setdefault("session_id", session_id)
+    train_cfg = cfg.get("train", {}) or {}
+    device = train_cfg.get("device", model_cfg.get("device", "cpu"))
+    model = build_model_instance(
+        model_cfg,
+        model_cfg_path,
+        dense_features,
+        sparse_features,
+        sequence_features,
+        target,
+        device,
+    )
+
+    model.compile(
+        optimizer=train_cfg.get("optimizer", "adam"),
+        optimizer_params=train_cfg.get("optimizer_params", {}),
+        loss=train_cfg.get("loss", "focal"),
+        loss_params=train_cfg.get("loss_params", {}),
+    )
+
+    model.fit(
+        train_data=train_loader,
+        valid_data=valid_loader,
+        metrics=train_cfg.get("metrics", ["auc", "recall", "precision"]),
+        epochs=train_cfg.get("epochs", 1),
+        batch_size=train_cfg.get(
+            "batch_size", dataloader_cfg.get("train_batch_size", 512)
+        ),
+        shuffle=train_cfg.get("shuffle", True),
+        num_workers=dataloader_cfg.get("num_workers", 0),
+        user_id_column=id_column,
+        tensorboard=False,
+    )
+
+
+def predict_model(predict_config_path: str) -> None:
+    """
+    Run prediction using a trained model and configuration file.
+    """
+    config_file = Path(predict_config_path)
+    config_dir = config_file.resolve().parent
+    cfg = read_yaml(config_file)
+
+    session_cfg = cfg.get("session", {}) or {}
+    session_id = session_cfg.get("id", "masknet_tutorial")
+    artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
+    session_dir = Path(cfg.get("checkpoint_path") or (artifact_root / session_id))
+    setup_logger(session_id=session_id)
+
+    processor_path = Path(session_dir / "processor.pkl")
+    if not processor_path.exists():
+        processor_path = session_dir / "processor" / "processor.pkl"
+
+    predict_cfg = cfg.get("predict", {}) or {}
+    model_cfg_path = resolve_path(
+        cfg.get("model_config", "model_config.yaml"), config_dir
+    )
+    # feature_cfg_path = resolve_path(
+    #     cfg.get("feature_config", "feature_config.yaml"), config_dir
+    # )
+
+    model_cfg = read_yaml(model_cfg_path)
+    # feature_cfg = read_yaml(feature_cfg_path)
+    model_cfg.setdefault("session_id", session_id)
+    model_cfg.setdefault("params", {})
+
+    processor = DataProcessor.load(processor_path)
+
+    # Load checkpoint and ensure required parameters are passed
+    checkpoint_base = Path(session_dir)
+    if checkpoint_base.is_dir():
+        candidates = sorted(checkpoint_base.glob("*.model"))
+        if not candidates:
+            raise FileNotFoundError(
+                f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
+            )
+        model_file = candidates[-1]
+        config_dir_for_features = checkpoint_base
+    else:
+        model_file = (
+            checkpoint_base.with_suffix(".model")
+            if checkpoint_base.suffix == ""
+            else checkpoint_base
+        )
+        config_dir_for_features = model_file.parent
+
+    features_config_path = config_dir_for_features / "features_config.pkl"
+    if not features_config_path.exists():
+        raise FileNotFoundError(
+            f"[NextRec CLI Error]: Unable to find features_config.pkl: {features_config_path}"
+        )
+    with open(features_config_path, "rb") as f:
+        features_config = pickle.load(f)
+
+    all_features = features_config.get("all_features", [])
+    target_cols = features_config.get("target", [])
+    id_columns = features_config.get("id_columns", [])
+
+    dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
+    sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
+    sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
+
+    target_override = (
+        cfg.get("targets")
+        or model_cfg.get("targets")
+        or model_cfg.get("params", {}).get("targets")
+        or model_cfg.get("params", {}).get("target")
+    )
+    if target_override:
+        target_cols = normalize_to_list(target_override)
+
+    model = build_model_instance(
+        model_cfg=model_cfg,
+        model_cfg_path=model_cfg_path,
+        dense_features=dense_features,
+        sparse_features=sparse_features,
+        sequence_features=sequence_features,
+        target=target_cols,
+        device=predict_cfg.get("device", "cpu"),
+    )
+    model.id_columns = id_columns
+    model.load_model(
+        model_file, map_location=predict_cfg.get("device", "cpu"), verbose=True
+    )
+
+    id_columns = []
+    if predict_cfg.get("id_column"):
+        id_columns = [predict_cfg["id_column"]]
+        model.id_columns = id_columns
+
+    rec_dataloader = RecDataLoader(
+        dense_features=model.dense_features,
+        sparse_features=model.sparse_features,
+        sequence_features=model.sequence_features,
+        target=None,
+        id_columns=id_columns or model.id_columns,
+        processor=processor,
+    )
+
+    data_path = resolve_path(predict_cfg["data_path"], config_dir)
+    batch_size = predict_cfg.get("batch_size", 512)
+
+    pred_loader = rec_dataloader.create_dataloader(
+        data=str(data_path),
+        batch_size=batch_size,
+        shuffle=False,
+        load_full=predict_cfg.get("load_full", False),
+        chunk_size=predict_cfg.get("chunk_size", 20000),
+    )
+
+    output_path = resolve_path(predict_cfg["output_path"], config_dir)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    start = time.time()
+    model.predict(
+        data=pred_loader,
+        batch_size=batch_size,
+        include_ids=bool(id_columns),
+        return_dataframe=False,
+        save_path=output_path,
+        save_format=predict_cfg.get("save_format", "csv"),
+        num_workers=predict_cfg.get("num_workers", 0),
+    )
+    duration = time.time() - start
+    logger.info(f"Prediction completed, results saved to: {output_path}")
+    logger.info(f"Total time: {duration:.2f} seconds")
+
+    preview_rows = predict_cfg.get("preview_rows", 0)
+    if preview_rows > 0:
+        try:
+            preview = pd.read_csv(output_path, nrows=preview_rows, low_memory=False)
+            logger.info(f"Output preview:\n{preview}")
+        except Exception as exc:  # pragma: no cover
+            logger.warning(f"Failed to read output preview: {exc}")
+
+
+def main() -> None:
+    """Parse CLI arguments and dispatch to train or predict mode."""
+    parser = argparse.ArgumentParser(
+        description="NextRec: Training and Prediction Pipeline",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # Train a model
+  nextrec --mode=train --train_config=configs/train_config.yaml
+
+  # Run prediction
+  nextrec --mode=predict --predict_config=configs/predict_config.yaml
+        """,
+    )
+    parser.add_argument(
+        "--mode",
+        choices=["train", "predict"],
+        required=True,
+        help="Running mode: train or predict",
+    )
+    parser.add_argument("--train_config", help="Training configuration file path")
+    parser.add_argument("--predict_config", help="Prediction configuration file path")
+    args = parser.parse_args()
+
+    if args.mode == "train":
+        config_path = args.train_config
+        if not config_path:
+            parser.error("[NextRec CLI Error] train mode requires --train_config")
+        train_model(config_path)
+    else:
+        config_path = args.predict_config
+        if not config_path:
+            parser.error("[NextRec CLI Error] predict mode requires --predict_config")
+        predict_model(config_path)
+
+
+if __name__ == "__main__":
+    main()
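
Editor's note, not part of the diff: to make the new CLI easier to follow, here is an illustrative sketch of the configuration structure that train_model() reads. The real file is YAML (see the module docstring); the keys below simply mirror the cfg.get(...) lookups in the code, and all values are placeholders rather than package defaults.

# Illustrative only: a Python dict mirroring the keys train_model() looks up in the training YAML.
train_config = {
    "session": {"id": "my_experiment", "artifact_root": "nextrec_logs"},
    "data": {
        "path": "data/train.csv",   # file or directory; enable streaming for chunked files
        "format": "csv",
        "target": "label",
        "streaming": False,
        "valid_ratio": 0.2,         # used when val_path / valid_path is not given
        "id_column": "user_id",     # optional, used for GAUC-style grouping
    },
    "dataloader": {"chunk_size": 20000, "train_batch_size": 512, "valid_batch_size": 512, "num_workers": 0},
    "feature_config": "feature_config.yaml",
    "model_config": "model_config.yaml",
    "train": {"optimizer": "adam", "loss": "focal", "metrics": ["auc", "recall", "precision"], "epochs": 1, "device": "cpu"},
}
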
nextrec/data/__init__.py CHANGED
@@ -27,35 +27,29 @@ from nextrec.data import data_utils
 
 __all__ = [
     # Batch utilities
-    'collate_fn',
-    'batch_to_dict',
-    'stack_section',
-
+    "collate_fn",
+    "batch_to_dict",
+    "stack_section",
     # Data processing
-    'get_column_data',
-    'split_dict_random',
-    'build_eval_candidates',
-    'get_user_ids',
-
+    "get_column_data",
+    "split_dict_random",
+    "build_eval_candidates",
+    "get_user_ids",
     # File utilities
-    'resolve_file_paths',
-    'iter_file_chunks',
-    'read_table',
-    'load_dataframes',
-    'default_output_dir',
-
+    "resolve_file_paths",
+    "iter_file_chunks",
+    "read_table",
+    "load_dataframes",
+    "default_output_dir",
     # DataLoader
-    'TensorDictDataset',
-    'FileDataset',
-    'RecDataLoader',
-    'build_tensors_from_data',
-
+    "TensorDictDataset",
+    "FileDataset",
+    "RecDataLoader",
+    "build_tensors_from_data",
     # Preprocessor
-    'DataProcessor',
-
+    "DataProcessor",
     # Features
-    'FeatureSet',
-
+    "FeatureSet",
     # Legacy module
-    'data_utils',
+    "data_utils",
 ]
nextrec/data/batch_utils.py CHANGED
@@ -9,16 +9,22 @@ import torch
 import numpy as np
 from typing import Any, Mapping
 
+
 def stack_section(batch: list[dict], section: str):
     entries = [item.get(section) for item in batch if item.get(section) is not None]
     if not entries:
         return None
     merged: dict = {}
     for name in entries[0]:  # type: ignore
-        tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
+        tensors = [
+            item[section][name]
+            for item in batch
+            if item.get(section) is not None and name in item[section]
+        ]
         merged[name] = torch.stack(tensors, dim=0)
     return merged
 
+
 def collate_fn(batch):
     """
     Collate a list of sample dicts into the unified batch format:
@@ -28,7 +34,7 @@ def collate_fn(batch):
         "ids": {id_name: Tensor(B, ...)} or None,
     }
     Args: batch: List of samples from DataLoader
-
+
     Returns: dict: Batched data in unified format
     """
     if not batch:
@@ -72,7 +78,9 @@ def collate_fn(batch):
 
 def batch_to_dict(batch_data: Any, include_ids: bool = True) -> dict:
     if not (isinstance(batch_data, Mapping) and "features" in batch_data):
-        raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
+        raise TypeError(
+            "[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader."
+        )
     return {
         "features": batch_data.get("features", {}),
        "labels": batch_data.get("labels"),