nextrec 0.4.8__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +30 -15
- nextrec/basic/features.py +1 -0
- nextrec/basic/layers.py +6 -8
- nextrec/basic/loggers.py +14 -7
- nextrec/basic/metrics.py +6 -76
- nextrec/basic/model.py +316 -321
- nextrec/cli.py +185 -43
- nextrec/data/__init__.py +13 -16
- nextrec/data/batch_utils.py +3 -2
- nextrec/data/data_processing.py +10 -2
- nextrec/data/data_utils.py +9 -14
- nextrec/data/dataloader.py +31 -33
- nextrec/data/preprocessor.py +328 -255
- nextrec/loss/__init__.py +1 -5
- nextrec/loss/loss_utils.py +2 -8
- nextrec/models/generative/__init__.py +1 -8
- nextrec/models/generative/hstu.py +6 -4
- nextrec/models/multi_task/esmm.py +2 -2
- nextrec/models/multi_task/mmoe.py +2 -2
- nextrec/models/multi_task/ple.py +2 -2
- nextrec/models/multi_task/poso.py +2 -3
- nextrec/models/multi_task/share_bottom.py +2 -2
- nextrec/models/ranking/afm.py +2 -2
- nextrec/models/ranking/autoint.py +2 -2
- nextrec/models/ranking/dcn.py +2 -2
- nextrec/models/ranking/dcn_v2.py +2 -2
- nextrec/models/ranking/deepfm.py +6 -7
- nextrec/models/ranking/dien.py +3 -3
- nextrec/models/ranking/din.py +3 -3
- nextrec/models/ranking/eulernet.py +365 -0
- nextrec/models/ranking/fibinet.py +5 -5
- nextrec/models/ranking/fm.py +3 -7
- nextrec/models/ranking/lr.py +120 -0
- nextrec/models/ranking/masknet.py +2 -2
- nextrec/models/ranking/pnn.py +2 -2
- nextrec/models/ranking/widedeep.py +2 -2
- nextrec/models/ranking/xdeepfm.py +2 -2
- nextrec/models/representation/__init__.py +9 -0
- nextrec/models/{generative → representation}/rqvae.py +9 -9
- nextrec/models/retrieval/__init__.py +0 -0
- nextrec/models/{match → retrieval}/dssm.py +8 -3
- nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
- nextrec/models/{match → retrieval}/mind.py +4 -3
- nextrec/models/{match → retrieval}/sdm.py +4 -3
- nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
- nextrec/utils/__init__.py +60 -46
- nextrec/utils/config.py +8 -7
- nextrec/utils/console.py +371 -0
- nextrec/utils/{synthetic_data.py → data.py} +102 -15
- nextrec/utils/feature.py +15 -0
- nextrec/utils/torch_utils.py +411 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/METADATA +6 -7
- nextrec-0.4.10.dist-info/RECORD +70 -0
- nextrec/utils/cli_utils.py +0 -58
- nextrec/utils/device.py +0 -78
- nextrec/utils/distributed.py +0 -141
- nextrec/utils/file.py +0 -92
- nextrec/utils/initializer.py +0 -79
- nextrec/utils/optimizer.py +0 -75
- nextrec/utils/tensor.py +0 -72
- nextrec-0.4.8.dist-info/RECORD +0 -71
- /nextrec/models/{match/__init__.py → ranking/ffm.py} +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/WHEEL +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.8.dist-info → nextrec-0.4.10.dist-info}/licenses/LICENSE +0 -0
nextrec/cli.py
CHANGED
@@ -18,10 +18,10 @@ Checkpoint: edit on 18/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-import sys
 import argparse
 import logging
 import pickle
+import sys
 import time
 from pathlib import Path
 from typing import Any, Dict, List
@@ -29,6 +29,7 @@ from typing import Any, Dict, List
 import pandas as pd
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.loggers import colorize, format_kv, setup_logger
 from nextrec.data.data_utils import split_dict_random
 from nextrec.data.dataloader import RecDataLoader
 from nextrec.data.preprocessor import DataProcessor
@@ -39,22 +40,29 @@ from nextrec.utils.config import (
     resolve_path,
     select_features,
 )
-from nextrec.utils.
-from nextrec.utils.
+from nextrec.utils.console import get_nextrec_version
+from nextrec.utils.data import (
     iter_file_chunks,
     read_table,
     read_yaml,
     resolve_file_paths,
 )
-from nextrec.utils.
-    get_nextrec_version,
-    log_startup_info,
-)
-from nextrec.basic.loggers import setup_logger
+from nextrec.utils.feature import normalize_to_list
 
 logger = logging.getLogger(__name__)
 
 
+def log_cli_section(title: str) -> None:
+    logger.info("")
+    logger.info(colorize(f"[{title}]", color="bright_blue", bold=True))
+    logger.info(colorize("-" * 80, color="bright_blue"))
+
+
+def log_kv_lines(items: list[tuple[str, Any]]) -> None:
+    for label, value in items:
+        logger.info(format_kv(label, value))
+
+
 def train_model(train_config_path: str) -> None:
     """
     Train a NextRec model using the provided configuration file.
@@ -77,8 +85,17 @@ def train_model(train_config_path: str) -> None:
     artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
     session_dir = artifact_root / session_id
     setup_logger(session_id=session_id)
-
-
+
+    log_cli_section("CLI")
+    log_kv_lines(
+        [
+            ("Mode", "train"),
+            ("Version", get_nextrec_version()),
+            ("Session ID", session_id),
+            ("Artifacts", session_dir.resolve()),
+            ("Config", config_file.resolve()),
+            ("Command", " ".join(sys.argv)),
+        ]
     )
 
     processor_path = session_dir / "processor.pkl"
@@ -105,11 +122,53 @@ def train_model(train_config_path: str) -> None:
         cfg.get("model_config", "model_config.yaml"), config_dir
     )
 
+    log_cli_section("Config")
+    log_kv_lines(
+        [
+            ("Train config", config_file.resolve()),
+            ("Feature config", feature_cfg_path),
+            ("Model config", model_cfg_path),
+        ]
+    )
+
     feature_cfg = read_yaml(feature_cfg_path)
     model_cfg = read_yaml(model_cfg_path)
 
+    # Extract id_column from data config for GAUC metrics
+    id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
+    id_columns = [id_column] if id_column else []
+
+    log_cli_section("Data")
+    log_kv_lines(
+        [
+            ("Data path", data_path),
+            ("Format", data_cfg.get("format", "auto")),
+            ("Streaming", streaming),
+            ("Target", target),
+            ("ID column", id_column or "(not set)"),
+        ]
+    )
+    if data_cfg.get("valid_ratio") is not None:
+        logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
+    if data_cfg.get("val_path") or data_cfg.get("valid_path"):
+        logger.info(
+            format_kv(
+                "Validation path",
+                resolve_path(
+                    data_cfg.get("val_path") or data_cfg.get("valid_path"), config_dir
+                ),
+            )
+        )
+
     if streaming:
         file_paths, file_type = resolve_file_paths(str(data_path))
+        log_kv_lines(
+            [
+                ("File type", file_type),
+                ("Files", len(file_paths)),
+                ("Chunk size", dataloader_chunk_size),
+            ]
+        )
         first_file = file_paths[0]
         first_chunk_size = max(1, min(dataloader_chunk_size, 1000))
         chunk_iter = iter_file_chunks(first_file, file_type, first_chunk_size)
@@ -121,14 +180,12 @@ def train_model(train_config_path: str) -> None:
 
     else:
         df = read_table(data_path, data_cfg.get("format"))
+        logger.info(format_kv("Rows", len(df)))
+        logger.info(format_kv("Columns", len(df.columns)))
         df_columns = list(df.columns)
 
     dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)
 
-    # Extract id_column from data config for GAUC metrics
-    id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
-    id_columns = [id_column] if id_column else []
-
     used_columns = dense_names + sparse_names + sequence_names + target + id_columns
 
     # keep order but drop duplicates
@@ -144,6 +201,17 @@ def train_model(train_config_path: str) -> None:
         processor, feature_cfg, dense_names, sparse_names, sequence_names
     )
 
+    log_cli_section("Features")
+    log_kv_lines(
+        [
+            ("Dense features", len(dense_names)),
+            ("Sparse features", len(sparse_names)),
+            ("Sequence features", len(sequence_names)),
+            ("Targets", len(target)),
+            ("Used columns", len(unique_used_columns)),
+        ]
+    )
+
     if streaming:
         processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
         processed = None
@@ -247,7 +315,7 @@ def train_model(train_config_path: str) -> None:
             data=train_stream_source,
             batch_size=dataloader_cfg.get("train_batch_size", 512),
             shuffle=dataloader_cfg.get("train_shuffle", True),
-
+            streaming=True,
             chunk_size=dataloader_chunk_size,
            num_workers=dataloader_cfg.get("num_workers", 0),
        )
@@ -258,7 +326,7 @@ def train_model(train_config_path: str) -> None:
            data=str(val_data_resolved),
            batch_size=dataloader_cfg.get("valid_batch_size", 512),
            shuffle=dataloader_cfg.get("valid_shuffle", False),
-
+            streaming=True,
            chunk_size=dataloader_chunk_size,
            num_workers=dataloader_cfg.get("num_workers", 0),
        )
@@ -267,7 +335,7 @@ def train_model(train_config_path: str) -> None:
            data=streaming_valid_files,
            batch_size=dataloader_cfg.get("valid_batch_size", 512),
            shuffle=dataloader_cfg.get("valid_shuffle", False),
-
+            streaming=True,
            chunk_size=dataloader_chunk_size,
            num_workers=dataloader_cfg.get("num_workers", 0),
        )
@@ -298,6 +366,15 @@ def train_model(train_config_path: str) -> None:
         device,
     )
 
+    log_cli_section("Model")
+    log_kv_lines(
+        [
+            ("Model", model.__class__.__name__),
+            ("Device", device),
+            ("Session ID", session_id),
+        ]
+    )
+
     model.compile(
         optimizer=train_cfg.get("optimizer", "adam"),
         optimizer_params=train_cfg.get("optimizer_params", {}),
@@ -328,13 +405,30 @@ def predict_model(predict_config_path: str) -> None:
     config_dir = config_file.resolve().parent
     cfg = read_yaml(config_file)
 
-
-
-
-
+    # Checkpoint path is the primary configuration
+    if "checkpoint_path" not in cfg:
+        session_cfg = cfg.get("session", {}) or {}
+        session_id = session_cfg.get("id", "nextrec_session")
+        artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
+        session_dir = artifact_root / session_id
+    else:
+        session_dir = Path(cfg["checkpoint_path"])
+        # Auto-infer session_id from checkpoint directory name
+        session_cfg = cfg.get("session", {}) or {}
+        session_id = session_cfg.get("id") or session_dir.name
+
     setup_logger(session_id=session_id)
-
-
+
+    log_cli_section("CLI")
+    log_kv_lines(
+        [
+            ("Mode", "predict"),
+            ("Version", get_nextrec_version()),
+            ("Session ID", session_id),
+            ("Checkpoint", session_dir.resolve()),
+            ("Config", config_file.resolve()),
+            ("Command", " ".join(sys.argv)),
+        ]
     )
 
     processor_path = Path(session_dir / "processor.pkl")
@@ -342,24 +436,38 @@ def predict_model(predict_config_path: str) -> None:
         processor_path = session_dir / "processor" / "processor.pkl"
 
     predict_cfg = cfg.get("predict", {}) or {}
-
-
-
-
-
-
+
+    # Auto-find model_config in checkpoint directory if not specified
+    if "model_config" in cfg:
+        model_cfg_path = resolve_path(cfg["model_config"], config_dir)
+    else:
+        # Try to find model_config.yaml in checkpoint directory
+        auto_model_cfg = session_dir / "model_config.yaml"
+        if auto_model_cfg.exists():
+            model_cfg_path = auto_model_cfg
+        else:
+            # Fallback to config directory
+            model_cfg_path = resolve_path("model_config.yaml", config_dir)
 
     model_cfg = read_yaml(model_cfg_path)
-    # feature_cfg = read_yaml(feature_cfg_path)
     model_cfg.setdefault("session_id", session_id)
     model_cfg.setdefault("params", {})
 
+    log_cli_section("Config")
+    log_kv_lines(
+        [
+            ("Predict config", config_file.resolve()),
+            ("Model config", model_cfg_path),
+            ("Processor", processor_path),
+        ]
+    )
+
     processor = DataProcessor.load(processor_path)
 
     # Load checkpoint and ensure required parameters are passed
     checkpoint_base = Path(session_dir)
     if checkpoint_base.is_dir():
-        candidates = sorted(checkpoint_base.glob("*.
+        candidates = sorted(checkpoint_base.glob("*.pt"))
         if not candidates:
             raise FileNotFoundError(
                 f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
@@ -368,7 +476,7 @@ def predict_model(predict_config_path: str) -> None:
         config_dir_for_features = checkpoint_base
     else:
         model_file = (
-            checkpoint_base.with_suffix(".
+            checkpoint_base.with_suffix(".pt")
             if checkpoint_base.suffix == ""
             else checkpoint_base
         )
@@ -418,40 +526,78 @@ def predict_model(predict_config_path: str) -> None:
         id_columns = [predict_cfg["id_column"]]
         model.id_columns = id_columns
 
+    effective_id_columns = id_columns or model.id_columns
+    log_cli_section("Features")
+    log_kv_lines(
+        [
+            ("Dense features", len(dense_features)),
+            ("Sparse features", len(sparse_features)),
+            ("Sequence features", len(sequence_features)),
+            ("Targets", len(target_cols)),
+            ("ID columns", len(effective_id_columns)),
+        ]
+    )
+
+    log_cli_section("Model")
+    log_kv_lines(
+        [
+            ("Model", model.__class__.__name__),
+            ("Checkpoint", model_file),
+            ("Device", predict_cfg.get("device", "cpu")),
+        ]
+    )
+
     rec_dataloader = RecDataLoader(
         dense_features=model.dense_features,
         sparse_features=model.sparse_features,
         sequence_features=model.sequence_features,
         target=None,
-        id_columns=
+        id_columns=effective_id_columns,
         processor=processor,
     )
 
     data_path = resolve_path(predict_cfg["data_path"], config_dir)
     batch_size = predict_cfg.get("batch_size", 512)
 
+    log_cli_section("Data")
+    log_kv_lines(
+        [
+            ("Data path", data_path),
+            ("Format", predict_cfg.get("source_data_format", predict_cfg.get("data_format", "auto"))),
+            ("Batch size", batch_size),
+            ("Chunk size", predict_cfg.get("chunk_size", 20000)),
+            ("Streaming", predict_cfg.get("streaming", True)),
+        ]
+    )
+    logger.info("")
     pred_loader = rec_dataloader.create_dataloader(
         data=str(data_path),
         batch_size=batch_size,
         shuffle=False,
-
+        streaming=predict_cfg.get("streaming", True),
         chunk_size=predict_cfg.get("chunk_size", 20000),
     )
 
-
-
+    # Build output path: {checkpoint_path}/predictions/{name}.{save_data_format}
+    save_format = predict_cfg.get("save_data_format", predict_cfg.get("save_format", "csv"))
+    pred_name = predict_cfg.get("name", "pred")
+    # Pass filename with extension to let model.predict handle path resolution
+    save_path = f"{pred_name}.{save_format}"
 
     start = time.time()
-
+    logger.info("")
+    result = model.predict(
         data=pred_loader,
         batch_size=batch_size,
         include_ids=bool(id_columns),
         return_dataframe=False,
-        save_path=
-        save_format=
+        save_path=save_path,
+        save_format=save_format,
         num_workers=predict_cfg.get("num_workers", 0),
     )
     duration = time.time() - start
+    # When return_dataframe=False, result is the actual file path
+    output_path = result if isinstance(result, Path) else checkpoint_base / "predictions" / save_path
     logger.info(f"Prediction completed, results saved to: {output_path}")
     logger.info(f"Total time: {duration:.2f} seconds")
 
@@ -495,8 +641,6 @@ Examples:
     parser.add_argument("--predict_config", help="Prediction configuration file path")
     args = parser.parse_args()
 
-    logger.info(get_nextrec_version())
-
     if not args.mode:
         parser.error("[NextRec CLI Error] --mode is required (train|predict)")
 
@@ -504,13 +648,11 @@ Examples:
         config_path = args.train_config
         if not config_path:
             parser.error("[NextRec CLI Error] train mode requires --train_config")
-        log_startup_info(logger, mode="train", config_path=config_path)
         train_model(config_path)
     else:
         config_path = args.predict_config
         if not config_path:
             parser.error("[NextRec CLI Error] predict mode requires --predict_config")
-        log_startup_info(logger, mode="predict", config_path=config_path)
         predict_model(config_path)
 
 
nextrec/data/__init__.py
CHANGED
@@ -1,29 +1,26 @@
-from nextrec.
+from nextrec.basic.features import FeatureSet
+from nextrec.data import data_utils
+from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
 from nextrec.data.data_processing import (
-    get_column_data,
-    split_dict_random,
     build_eval_candidates,
+    get_column_data,
     get_user_ids,
+    split_dict_random,
 )
-
-from nextrec.utils.file import (
-    resolve_file_paths,
-    iter_file_chunks,
-    read_table,
-    load_dataframes,
-    default_output_dir,
-)
-
 from nextrec.data.dataloader import (
-    TensorDictDataset,
     FileDataset,
     RecDataLoader,
+    TensorDictDataset,
     build_tensors_from_data,
 )
-
 from nextrec.data.preprocessor import DataProcessor
-from nextrec.
-
+from nextrec.utils.data import (
+    default_output_dir,
+    iter_file_chunks,
+    load_dataframes,
+    read_table,
+    resolve_file_paths,
+)
 
 __all__ = [
     # Batch utilities
nextrec/data/batch_utils.py
CHANGED
@@ -5,10 +5,11 @@ Date: create on 03/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-import torch
-import numpy as np
 from typing import Any, Mapping
 
+import numpy as np
+import torch
+
 
 def stack_section(batch: list[dict], section: str):
     entries = [item.get(section) for item in batch if item.get(section) is not None]
nextrec/data/data_processing.py
CHANGED
@@ -2,13 +2,16 @@
 Data processing utilities for NextRec
 
 Date: create on 03/12/2025
+Checkpoint: edit on 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-import
+import hashlib
+from typing import Any
+
 import numpy as np
 import pandas as pd
-
+import torch
 
 
 def get_column_data(data: dict | pd.DataFrame, name: str):
@@ -166,3 +169,8 @@ def get_user_ids(
         return arr.reshape(arr.shape[0])
 
     return None
+
+
+def hash_md5_mod(value: str, hash_size: int) -> int:
+    digest = hashlib.md5(value.encode("utf-8")).digest()
+    return int.from_bytes(digest, byteorder="big", signed=False) % hash_size
nextrec/data/data_utils.py
CHANGED
@@ -1,30 +1,25 @@
 """
-Data processing utilities for NextRec
-
-This module now re-exports functions from specialized submodules:
-- batch_utils: collate_fn, batch_to_dict
-- data_processing: get_column_data, split_dict_random, build_eval_candidates, get_user_ids
-- nextrec.utils.file_utils: resolve_file_paths, iter_file_chunks, read_table, load_dataframes, default_output_dir
+Data processing utilities for NextRec
 
 Date: create on 27/10/2025
-Last update:
+Last update: 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
 # Import from new organized modules
-from nextrec.data.batch_utils import
+from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
 from nextrec.data.data_processing import (
-    get_column_data,
-    split_dict_random,
     build_eval_candidates,
+    get_column_data,
     get_user_ids,
+    split_dict_random,
 )
-from nextrec.utils.
-
+from nextrec.utils.data import (
+    default_output_dir,
     iter_file_chunks,
-    read_table,
     load_dataframes,
-
+    read_table,
+    resolve_file_paths,
 )
 
 __all__ = [
nextrec/data/dataloader.py
CHANGED
@@ -2,33 +2,32 @@
 Dataloader definitions
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 19/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
-import os
-import torch
 import logging
+import os
+from pathlib import Path
+from typing import cast
+
 import numpy as np
 import pandas as pd
 import pyarrow.parquet as pq
-
-from
-from typing import cast
+import torch
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from nextrec.basic.features import (
     DenseFeature,
-    SparseFeature,
-    SequenceFeature,
     FeatureSet,
+    SequenceFeature,
+    SparseFeature,
 )
-from nextrec.data.preprocessor import DataProcessor
-from torch.utils.data import DataLoader, Dataset, IterableDataset
-
-from nextrec.utils.tensor import to_tensor
-from nextrec.utils.file import resolve_file_paths, read_table
 from nextrec.data.batch_utils import collate_fn
 from nextrec.data.data_processing import get_column_data
+from nextrec.data.preprocessor import DataProcessor
+from nextrec.utils.data import read_table, resolve_file_paths
+from nextrec.utils.torch_utils import to_tensor
 
 
 class TensorDictDataset(Dataset):
@@ -103,9 +102,8 @@ class FileDataset(FeatureSet, IterableDataset):
         self.current_file_index = 0
         for file_path in self.file_paths:
             self.current_file_index += 1
-
-
-            logging.info(f"Processing file: {file_name}")
+            # Don't log file processing here to avoid interrupting progress bars
+            # File information is already displayed in the CLI data section
             if self.file_type == "csv":
                 yield from self.read_csv_chunks(file_path)
             elif self.file_type == "parquet":
@@ -191,7 +189,7 @@ class RecDataLoader(FeatureSet):
         ),
         batch_size: int = 32,
         shuffle: bool = True,
-
+        streaming: bool = False,
         chunk_size: int = 10000,
         num_workers: int = 0,
         sampler=None,
@@ -203,7 +201,7 @@ class RecDataLoader(FeatureSet):
             data: Data source, can be a dict, pd.DataFrame, file path (str), or existing DataLoader.
             batch_size: Batch size for DataLoader.
             shuffle: Whether to shuffle the data (ignored in streaming mode).
-
+            streaming: If True, use streaming mode for large files; if False, load full data into memory.
             chunk_size: Chunk size for streaming mode (number of rows per chunk).
             num_workers: Number of worker processes for data loading.
             sampler: Optional sampler for DataLoader, only used for distributed training.
@@ -218,7 +216,7 @@ class RecDataLoader(FeatureSet):
                 path=data,
                 batch_size=batch_size,
                 shuffle=shuffle,
-
+                streaming=streaming,
                 chunk_size=chunk_size,
                 num_workers=num_workers,
             )
@@ -231,7 +229,7 @@ class RecDataLoader(FeatureSet):
                 path=data,
                 batch_size=batch_size,
                 shuffle=shuffle,
-
+                streaming=streaming,
                 chunk_size=chunk_size,
                 num_workers=num_workers,
             )
@@ -291,7 +289,7 @@ class RecDataLoader(FeatureSet):
         path: str | os.PathLike | list[str] | list[os.PathLike],
         batch_size: int,
         shuffle: bool,
-
+        streaming: bool,
         chunk_size: int = 10000,
         num_workers: int = 0,
     ) -> DataLoader:
@@ -312,8 +310,17 @@ class RecDataLoader(FeatureSet):
                     f"[RecDataLoader Error] Unsupported file extension in list: {suffix}"
                 )
             file_type = "csv" if suffix == ".csv" else "parquet"
+        if streaming:
+            return self.load_files_streaming(
+                file_paths,
+                file_type,
+                batch_size,
+                chunk_size,
+                shuffle,
+                num_workers=num_workers,
+            )
         # Load full data into memory
-
+        else:
             dfs = []
             total_bytes = 0
             for file_path in file_paths:
@@ -326,26 +333,17 @@ class RecDataLoader(FeatureSet):
                     dfs.append(df)
                 except MemoryError as exc:
                     raise MemoryError(
-                        f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using
+                        f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using streaming=True."
                     ) from exc
             try:
                 combined_df = pd.concat(dfs, ignore_index=True)
             except MemoryError as exc:
                 raise MemoryError(
-                    f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use
+                    f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use streaming=True or reduce chunk_size."
                 ) from exc
             return self.create_from_memory(
                 combined_df, batch_size, shuffle, num_workers=num_workers
             )
-        else:
-            return self.load_files_streaming(
-                file_paths,
-                file_type,
-                batch_size,
-                chunk_size,
-                shuffle,
-                num_workers=num_workers,
-            )
 
     def load_files_streaming(
         self,