nextrec-0.5.0-py3-none-any.whl → nextrec-0.5.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +288 -181
- nextrec/basic/summary.py +21 -4
- nextrec/cli.py +36 -17
- nextrec/data/__init__.py +0 -52
- nextrec/data/batch_utils.py +1 -1
- nextrec/data/data_processing.py +1 -35
- nextrec/data/data_utils.py +0 -4
- nextrec/data/dataloader.py +125 -103
- nextrec/data/preprocessor.py +141 -92
- nextrec/loss/__init__.py +0 -36
- nextrec/models/generative/__init__.py +0 -9
- nextrec/models/tree_base/__init__.py +0 -15
- nextrec/models/tree_base/base.py +14 -23
- nextrec/utils/__init__.py +0 -119
- nextrec/utils/data.py +39 -119
- nextrec/utils/model.py +5 -14
- nextrec/utils/torch_utils.py +6 -1
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/METADATA +4 -5
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/RECORD +23 -23
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/WHEEL +0 -0
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/entry_points.txt +0 -0
- {nextrec-0.5.0.dist-info → nextrec-0.5.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/summary.py
CHANGED
@@ -13,10 +13,13 @@ import logging
 from typing import Any, Literal
 
 import numpy as np
+import pandas as pd
+import polars as pl
 from torch.utils.data import DataLoader
 
 from nextrec.basic.loggers import colorize, format_kv
-from nextrec.data.data_processing import …
+from nextrec.data.data_processing import get_column_data, get_data_length
+from nextrec.utils.torch_utils import to_numpy
 from nextrec.utils.types import TaskTypeName
 
 
@@ -82,9 +85,23 @@ class SummarySet:
         if train_size is None:
            train_size = get_data_length(data)
 
-        labels = …
-        if …
-        …
+        labels = None
+        if self.target_columns:
+            for source in (dataset, data):
+                if source is None:
+                    continue
+                label_source = source.labels if hasattr(source, "labels") else source  # type: ignore
+                if not isinstance(label_source, (dict, pd.DataFrame, pl.DataFrame)):
+                    continue
+                label_map = {}
+                for name in self.target_columns:
+                    column = get_column_data(label_source, name)  # type: ignore
+                    if column is None:
+                        continue
+                    label_map[name] = to_numpy(column)
+                labels = label_map or None
+                if labels:
+                    break
 
         summary = {}
         if train_size is not None:
nextrec/cli.py
CHANGED
@@ -14,7 +14,7 @@ Examples:
     nextrec --mode=predict --predict_config=nextrec_cli_preset/predict_config.yaml
 
 Date: create on 06/12/2025
-Checkpoint: edit on …
+Checkpoint: edit on 29/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -112,10 +112,10 @@ def train_model(train_config_path: str) -> None:
     # train data
     data_path = resolve_path(data_cfg["path"], config_dir)
     target = to_list(data_cfg["target"])
-    file_paths …
-    file_type …
-    streaming_train_files …
-    streaming_valid_files …
+    file_paths = []
+    file_type = None
+    streaming_train_files = None
+    streaming_valid_files = None
 
     feature_cfg_path = resolve_path(
         cfg.get("feature_config", "feature_config.yaml"), config_dir
@@ -251,7 +251,6 @@ def train_model(train_config_path: str) -> None:
         processor.fit_from_files(
             file_paths=streaming_train_files or file_paths,
             file_type=file_type,
-            chunk_size=dataloader_chunk_size,
         )
         processed = None
         df = None  # type: ignore[assignment]
@@ -653,6 +652,7 @@ def predict_model(predict_config_path: str) -> None:
     streaming = bool(predict_cfg.get("streaming", True))
     chunk_size = int(predict_cfg.get("chunk_size", 20000))
     batch_size = int(predict_cfg.get("batch_size", 512))
+    num_processes = int(predict_cfg.get("num_processes", 1))
     effective_batch_size = chunk_size if streaming else batch_size
 
     log_cli_section("Data")
@@ -668,17 +668,35 @@ def predict_model(predict_config_path: str) -> None:
             ("Batch size", effective_batch_size),
             ("Chunk size", chunk_size),
             ("Streaming", streaming),
+            ("Num processes", num_processes),
         ]
     )
+    if num_processes > 1 and predict_cfg.get("num_workers", 0) != 0:
+        logger.info("")
+        logger.info(
+            "[NextRec CLI Info] Multi-process streaming enforces num_workers=0 for each shard."
+        )
     logger.info("")
-…
-…
-…
-…
-…
-…
-…
-…
+    if num_processes > 1:
+        if not streaming:
+            raise ValueError(
+                "[NextRec CLI Error] num_processes > 1 requires streaming=true."
+            )
+        if use_onnx:
+            raise ValueError(
+                "[NextRec CLI Error] num_processes > 1 is not supported with ONNX inference."
+            )
+        pred_data = str(data_path)
+    else:
+        pred_data = rec_dataloader.create_dataloader(
+            data=str(data_path),
+            batch_size=1 if streaming else batch_size,
+            shuffle=False,
+            streaming=streaming,
+            chunk_size=chunk_size,
+            num_workers=predict_cfg.get("num_workers", 0),
+            prefetch_factor=predict_cfg.get("prefetch_factor"),
+        )
 
     save_format = predict_cfg.get(
         "save_data_format", predict_cfg.get("save_format", "csv")
@@ -697,7 +715,7 @@ def predict_model(predict_config_path: str) -> None:
     if use_onnx:
         result = model.predict_onnx(
             onnx_path=onnx_path,
-            data=…
+            data=pred_data,
             batch_size=effective_batch_size,
             include_ids=bool(id_columns),
             return_dataframe=False,
@@ -707,13 +725,14 @@ def predict_model(predict_config_path: str) -> None:
         )
     else:
         result = model.predict(
-            data=…
+            data=pred_data,
             batch_size=effective_batch_size,
-            include_ids=bool(id_columns),
             return_dataframe=False,
             save_path=str(save_path),
             save_format=save_format,
             num_workers=predict_cfg.get("num_workers", 0),
+            num_processes=num_processes,
+            processor=processor,
        )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
nextrec/data/__init__.py
CHANGED
@@ -1,52 +0,0 @@
-from nextrec.basic.features import FeatureSet
-from nextrec.data import data_utils
-from nextrec.data.batch_utils import batch_to_dict, collate_fn, stack_section
-from nextrec.data.data_processing import (
-    build_eval_candidates,
-    get_column_data,
-    get_user_ids,
-    split_dict_random,
-)
-from nextrec.data.dataloader import (
-    FileDataset,
-    RecDataLoader,
-    TensorDictDataset,
-    build_tensors_from_data,
-)
-from nextrec.data.preprocessor import DataProcessor
-from nextrec.utils.data import (
-    default_output_dir,
-    iter_file_chunks,
-    load_dataframes,
-    read_table,
-    resolve_file_paths,
-)
-
-__all__ = [
-    # Batch utilities
-    "collate_fn",
-    "batch_to_dict",
-    "stack_section",
-    # Data processing
-    "get_column_data",
-    "split_dict_random",
-    "build_eval_candidates",
-    "get_user_ids",
-    # File utilities
-    "resolve_file_paths",
-    "iter_file_chunks",
-    "read_table",
-    "load_dataframes",
-    "default_output_dir",
-    # DataLoader
-    "TensorDictDataset",
-    "FileDataset",
-    "RecDataLoader",
-    "build_tensors_from_data",
-    # Preprocessor
-    "DataProcessor",
-    # Features
-    "FeatureSet",
-    # Legacy module
-    "data_utils",
-]
nextrec/data/batch_utils.py
CHANGED
@@ -64,7 +64,7 @@ def collate_fn(batch):
     first = batch[0]
     if isinstance(first, dict) and "features" in first:
         # Streaming dataset yields already-batched chunks; avoid adding an extra dim.
-        if first.get("…
+        if first.get("stream_mode") and len(batch) == 1:
             return {
                 "features": first.get("features", {}),
                 "labels": first.get("labels"),
nextrec/data/data_processing.py
CHANGED
@@ -6,7 +6,6 @@ Checkpoint: edit on 25/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-import hashlib
 from typing import Any
 
 import numpy as np
@@ -15,9 +14,6 @@ import torch
 import polars as pl
 
 
-from nextrec.utils.torch_utils import to_numpy
-
-
 def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
 
     if isinstance(data, dict):
@@ -32,8 +28,7 @@ def get_column_data(data: dict | pd.DataFrame | pl.DataFrame, name: str):
 
 
 def get_data_length(data: Any) -> int | None:
-    …
-    return None
+
     if isinstance(data, pd.DataFrame):
         return len(data)
     if isinstance(data, pl.DataFrame):
@@ -43,33 +38,9 @@ def get_data_length(data: Any) -> int | None:
             return None
         sample_key = next(iter(data))
         return len(data[sample_key])
-    try:
-        return len(data)
-    except TypeError:
-        return None
-
-
-def extract_label_arrays(
-    data: Any, target_columns: list[str]
-) -> dict[str, np.ndarray] | None:
-    if not target_columns or data is None:
-        return None
-
-    if isinstance(data, (dict, pd.DataFrame)):
-        label_source = data
-    elif hasattr(data, "labels"):
-        label_source = data.labels
     else:
         return None
 
-    labels: dict[str, np.ndarray] = {}
-    for name in target_columns:
-        column = get_column_data(label_source, name)
-        if column is None:
-            continue
-        labels[name] = to_numpy(column)
-    return labels or None
-
 
 def split_dict_random(data_dict, test_size=0.2, random_state=None):
 
@@ -202,8 +173,3 @@ def get_user_ids(
         return arr.reshape(arr.shape[0])
 
     return None
-
-
-def hash_md5_mod(value: str, hash_size: int) -> int:
-    digest = hashlib.md5(value.encode("utf-8")).digest()
-    return int.from_bytes(digest, byteorder="big", signed=False) % hash_size
nextrec/data/data_utils.py
CHANGED
@@ -15,9 +15,7 @@ from nextrec.data.data_processing import (
     split_dict_random,
 )
 from nextrec.utils.data import (
-    default_output_dir,
     iter_file_chunks,
-    load_dataframes,
     read_table,
     resolve_file_paths,
 )
@@ -36,6 +34,4 @@ __all__ = [
     "resolve_file_paths",
     "iter_file_chunks",
     "read_table",
-    "load_dataframes",
-    "default_output_dir",
 ]