nextrec-0.5.1-py3-none-any.whl → nextrec-0.5.2-py3-none-any.whl
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +288 -181
- nextrec/basic/summary.py +21 -4
- nextrec/cli.py +35 -15
- nextrec/data/__init__.py +0 -52
- nextrec/data/batch_utils.py +1 -1
- nextrec/data/data_processing.py +1 -35
- nextrec/data/data_utils.py +0 -4
- nextrec/data/dataloader.py +125 -103
- nextrec/data/preprocessor.py +141 -92
- nextrec/loss/__init__.py +0 -36
- nextrec/models/generative/__init__.py +0 -9
- nextrec/models/tree_base/__init__.py +0 -15
- nextrec/models/tree_base/base.py +14 -23
- nextrec/utils/__init__.py +0 -119
- nextrec/utils/data.py +39 -119
- nextrec/utils/model.py +5 -14
- {nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/METADATA +4 -5
- {nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/RECORD +22 -22
- {nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/WHEEL +0 -0
- {nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/entry_points.txt +0 -0
- {nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/licenses/LICENSE +0 -0
nextrec/models/tree_base/base.py CHANGED

@@ -28,7 +28,6 @@ from nextrec.basic.session import create_session, get_save_path
 from nextrec.data.dataloader import RecDataLoader
 from nextrec.data.data_processing import get_column_data
 from nextrec.utils.console import display_metrics_table
-from nextrec.utils.data import FILE_FORMAT_CONFIG, check_streaming_support
 from nextrec.utils.torch_utils import to_list
 from nextrec.utils.torch_utils import to_numpy
 
@@ -454,7 +453,7 @@ class TreeBaseModel(FeatureSet):
         stream_chunk_size: int = 10000,
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray | Path | None:
-        del batch_size
+        del batch_size  # not used for tree models
 
         if self.model is None:
             raise ValueError(f"[{self.model_name}-predict Error] Model is not loaded.")
@@ -472,6 +471,7 @@ class TreeBaseModel(FeatureSet):
                 include_ids=include_ids,
                 stream_chunk_size=stream_chunk_size,
                 id_columns=predict_id_columns,
+                num_workers=num_workers,
             )
 
         if isinstance(data, (str, os.PathLike)):
@@ -508,12 +508,13 @@ class TreeBaseModel(FeatureSet):
         output = pred_df if return_dataframe else y_pred
 
         if save_path is not None:
-
+            if save_format not in {"csv", "parquet"}:
+                raise ValueError(f"Unsupported save format: {save_format}")
             target_path = get_save_path(
                 path=save_path,
                 default_dir=self.session.predictions_dir,
                 default_name="predictions",
-                suffix=
+                suffix=f".{save_format}",
                 add_timestamp=True if save_path is None else False,
             )
             if isinstance(output, pd.DataFrame):
@@ -527,12 +528,6 @@ class TreeBaseModel(FeatureSet):
                     df_to_save.to_csv(target_path, index=False)
                 elif save_format == "parquet":
                     df_to_save.to_parquet(target_path, index=False)
-                elif save_format == "feather":
-                    df_to_save.to_feather(target_path)
-                elif save_format == "excel":
-                    df_to_save.to_excel(target_path, index=False)
-                elif save_format == "hdf5":
-                    df_to_save.to_hdf(target_path, key="predictions", mode="w")
                 else:
                     raise ValueError(f"Unsupported save format: {save_format}")
             logging.info(f"Predictions saved to: {target_path}")
@@ -546,6 +541,7 @@ class TreeBaseModel(FeatureSet):
         include_ids: bool,
         stream_chunk_size: int,
         id_columns: list[str] | None,
+        num_workers: int = 0,
     ) -> Path:
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
@@ -561,25 +557,27 @@ class TreeBaseModel(FeatureSet):
                 shuffle=False,
                 streaming=True,
                 chunk_size=stream_chunk_size,
+                num_workers=num_workers,
             )
         else:
             data_loader = data
 
-        if not
+        if save_format.lower() not in {"csv", "parquet"}:
             logging.warning(
                 f"[{self.model_name}-predict Warning] Format '{save_format}' does not support streaming writes."
             )
 
-
+        if save_format not in {"csv", "parquet"}:
+            raise ValueError(f"Unsupported save format: {save_format}")
         target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
-            suffix=
+            suffix=f".{save_format}",
             add_timestamp=True if save_path is None else False,
         )
 
-        header_written =
+        header_written = target_path.exists()
        parquet_writer = None
        collected_frames: list[pd.DataFrame] = []
        id_column = id_columns[0] if id_columns else None
@@ -595,6 +593,7 @@ class TreeBaseModel(FeatureSet):
                 pred_df.to_csv(
                     target_path, mode="a", header=not header_written, index=False
                 )
+                header_written = True
             elif save_format == "parquet":
                 try:
                     import pyarrow as pa
@@ -609,19 +608,11 @@ class TreeBaseModel(FeatureSet):
                     parquet_writer.write_table(table)
                 else:
                     collected_frames.append(pred_df)
-            header_written = True
         if parquet_writer is not None:
             parquet_writer.close()
         if collected_frames:
             combined_df = pd.concat(collected_frames, ignore_index=True)
-
-                combined_df.to_feather(target_path)
-            elif save_format == "excel":
-                combined_df.to_excel(target_path, index=False)
-            elif save_format == "hdf5":
-                combined_df.to_hdf(target_path, key="predictions", mode="w")
-            else:
-                raise ValueError(f"Unsupported save format: {save_format}")
+            raise ValueError(f"Unsupported save format: {save_format}")
         return target_path
 
     def save_model(self, save_path: str | os.PathLike | None = None) -> Path:
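The hunks above narrow `TreeBaseModel` prediction saving to csv and parquet and make the streaming path write incrementally: csv chunks are appended with the header emitted exactly once, and parquet chunks go through a lazily created writer. Below is a minimal, self-contained sketch of that write pattern; `stream_predictions` and its arguments are illustrative names, not NextRec's API.

```python
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq


def stream_predictions(chunks, target_path, save_format="parquet"):
    """Write prediction chunks incrementally; csv appends, parquet uses a lazy writer."""
    if save_format not in {"csv", "parquet"}:
        raise ValueError(f"Unsupported save format: {save_format}")
    header_written = False
    parquet_writer = None
    for pred_df in chunks:
        if save_format == "csv":
            # Append each chunk, emitting the header only for the first one.
            pred_df.to_csv(target_path, mode="a", header=not header_written, index=False)
            header_written = True
        else:
            table = pa.Table.from_pandas(pred_df, preserve_index=False)
            if parquet_writer is None:
                # Create the writer lazily from the first chunk's schema.
                parquet_writer = pq.ParquetWriter(target_path, table.schema)
            parquet_writer.write_table(table)
    if parquet_writer is not None:
        parquet_writer.close()


# Two small chunks streamed into one parquet file (hypothetical data).
chunks = [
    pd.DataFrame({"user_id": [1, 2], "score": [0.91, 0.12]}),
    pd.DataFrame({"user_id": [3, 4], "score": [0.48, 0.77]}),
]
stream_predictions(chunks, "predictions.parquet")
```

Creating the `ParquetWriter` from the first chunk's schema, as the diff does, avoids having to know column dtypes before any predictions arrive.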
nextrec/utils/__init__.py CHANGED

@@ -1,119 +0,0 @@
-"""
-Utilities package for NextRec
-
-Date: create on 13/11/2025
-Last update: 19/12/2025
-Author: Yang Zhou, zyaztec@gmail.com
-"""
-
-from . import console, data, embedding, loss, torch_utils
-from .config import (
-    build_feature_objects,
-    build_model_instance,
-    extract_feature_groups,
-    load_model_class,
-    register_processor_features,
-    resolve_path,
-    safe_value,
-    select_features,
-)
-from .console import (
-    display_metrics_table,
-    get_nextrec_version,
-    log_startup_info,
-    progress,
-)
-from .data import (
-    default_output_dir,
-    generate_distributed_ranking_data,
-    generate_match_data,
-    generate_multitask_data,
-    generate_ranking_data,
-    iter_file_chunks,
-    load_dataframes,
-    read_table,
-    read_yaml,
-    resolve_file_paths,
-)
-from .embedding import get_auto_embedding_dim
-from .torch_utils import as_float, to_list
-from .model import (
-    compute_pair_scores,
-    get_mlp_output_dim,
-    merge_features,
-)
-from .loss import normalize_task_loss
-from .torch_utils import (
-    add_distributed_sampler,
-    get_device,
-    gather_numpy,
-    get_initializer,
-    get_optimizer,
-    get_scheduler,
-    init_process_group,
-    to_tensor,
-)
-from .types import LossName, OptimizerName, SchedulerName, ActivationName
-
-__all__ = [
-    # Console utilities
-    "get_nextrec_version",
-    "log_startup_info",
-    "progress",
-    "display_metrics_table",
-    # Optimizer & Scheduler (torch utils)
-    "get_optimizer",
-    "get_scheduler",
-    # Initializer (torch utils)
-    "get_initializer",
-    # Embedding utilities
-    "get_auto_embedding_dim",
-    # Device utilities (torch utils)
-    "get_device",
-    "init_process_group",
-    "gather_numpy",
-    "add_distributed_sampler",
-    # Tensor utilities
-    "to_tensor",
-    # Data utilities
-    "resolve_file_paths",
-    "read_table",
-    "read_yaml",
-    "load_dataframes",
-    "iter_file_chunks",
-    "default_output_dir",
-    # Model utilities
-    "merge_features",
-    "get_mlp_output_dim",
-    "compute_pair_scores",
-    # Loss utilities
-    "normalize_task_loss",
-    # Feature utilities
-    "to_list",
-    "as_float",
-    # Config utilities
-    "resolve_path",
-    "safe_value",
-    "register_processor_features",
-    "build_feature_objects",
-    "extract_feature_groups",
-    "select_features",
-    "load_model_class",
-    "build_model_instance",
-    # Synthetic data utilities
-    "generate_ranking_data",
-    "generate_match_data",
-    "generate_multitask_data",
-    "generate_distributed_ranking_data",
-    # Module exports
-    "console",
-    "data",
-    "embedding",
-    "loss",
-    "torch_utils",
-    # Type aliases
-    "OptimizerName",
-    "SchedulerName",
-    "LossName",
-    "ActivationName",
-]
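With the package-level re-exports removed, code that imported helpers from `nextrec.utils` directly would presumably need to target the submodules instead. A hedged before/after sketch, with submodule paths taken from the removed file above:

```python
# Pre-0.5.2, via the now-removed package-level re-exports:
#     from nextrec.utils import read_table, get_device

# 0.5.2 and later, importing from the concrete submodules:
from nextrec.utils.data import read_table
from nextrec.utils.torch_utils import get_device
```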
nextrec/utils/data.py CHANGED

@@ -4,7 +4,7 @@ Data utilities for NextRec.
 This module provides file I/O helpers and synthetic data generation.
 
 Date: create on 19/12/2025
-Checkpoint: edit on
+Checkpoint: edit on 29/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -19,46 +19,6 @@ import pyarrow.parquet as pq
 import torch
 import yaml
 
-FILE_FORMAT_CONFIG = {
-    "csv": {
-        "extension": [".csv", ".txt"],
-        "streaming": True,
-    },
-    "parquet": {
-        "extension": [".parquet"],
-        "streaming": True,
-    },
-    "feather": {
-        "extension": [".feather", ".ftr"],
-        "streaming": False,
-    },
-    "excel": {
-        "extension": [".xlsx", ".xls"],
-        "streaming": False,
-    },
-    "hdf5": {
-        "extension": [".h5", ".hdf5"],
-        "streaming": False,
-    },
-}
-
-
-def get_file_format_from_extension(ext: str) -> str | None:
-    """Get file format from extension."""
-    return {
-        ext.lstrip("."): fmt
-        for fmt, config in FILE_FORMAT_CONFIG.items()
-        for ext in config["extension"]
-    }.get(ext.lower().lstrip("."))
-
-
-def check_streaming_support(file_format: str) -> bool:
-    """Check if a format supports streaming."""
-    file_format = file_format.lower()
-    if file_format not in FILE_FORMAT_CONFIG:
-        return False
-    return FILE_FORMAT_CONFIG[file_format].get("streaming", False)
-
 
 def resolve_file_paths(path: str) -> tuple[list[str], str]:
     """
@@ -70,34 +30,45 @@ def resolve_file_paths(path: str) -> tuple[list[str], str]:
     path_obj = Path(path)
 
     if path_obj.is_file():
-
+        name = path_obj.name
+        ext = name.rsplit(".", 1)[-1].lower() if "." in name else ""
+        if ext in {"csv", "txt"}:
+            file_format = "csv"
+        elif ext == "parquet":
+            file_format = "parquet"
+        else:
+            file_format = None
         if file_format is None:
             raise ValueError(
-                f"Unsupported file extension: {path_obj.suffix}. "
-                f"Supported formats: {', '.join(FILE_FORMAT_CONFIG.keys())}"
+                f"Unsupported file extension: {path_obj.suffix}. Supported formats: csv, parquet."
             )
         return [str(path_obj)], file_format
 
     if path_obj.is_dir():
         collected_files = [p for p in path_obj.iterdir() if p.is_file()]
-
-        format_groups
+
+        format_groups = {}
         for file in collected_files:
-
+            name = file.name
+            ext = name.rsplit(".", 1)[-1].lower() if "." in name else ""
+            if ext in {"csv", "txt"}:
+                file_format = "csv"
+            elif ext == "parquet":
+                file_format = "parquet"
+            else:
+                file_format = None
             if file_format:
                 format_groups.setdefault(file_format, []).append(str(file))
 
         if len(format_groups) > 1:
             formats = ", ".join(format_groups.keys())
             raise ValueError(
-                f"Directory contains mixed file formats: {formats}. "
-                "Please keep a single format per directory."
+                f"Directory contains mixed file formats: {formats}. Please keep a single format per directory."
             )
 
         if not format_groups:
             raise ValueError(
-                f"No supported data files found in directory: {path}. "
-                f"Supported formats: {', '.join(FILE_FORMAT_CONFIG.keys())}"
+                f"No supported data files found in directory: {path}. Supported formats: csv, parquet."
             )
 
         file_type = list(format_groups.keys())[0]
@@ -111,18 +82,14 @@ def resolve_file_paths(path: str) -> tuple[list[str], str]:
 def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame:
     data_path = Path(path)
 
-    # Determine format
     if data_format:
-        fmt = data_format
+        fmt = data_format
     elif data_path.is_dir():
         _, fmt = resolve_file_paths(str(data_path))
     else:
-
-
-
-            f"Cannot determine format for {data_path}. "
-            f"Please specify data_format parameter."
-        )
+        raise ValueError(
+            f"Cannot determine format for {data_path}. Please specify data_format parameter."
+        )
 
     if data_path.is_dir():
         file_paths, _ = resolve_file_paths(str(data_path))
@@ -133,36 +100,11 @@ def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame
         return dataframes[0]
     return pd.concat(dataframes, ignore_index=True)
 
-
-
-
-
-
-        if len(store.keys()) == 0:
-            raise ValueError(f"HDF5 file {data_path} contains no datasets")
-        return pd.read_hdf(data_path, key=store.keys()[0])
-    reader = {
-        "parquet": pd.read_parquet,
-        "csv": lambda p: pd.read_csv(p, low_memory=False),
-        "feather": pd.read_feather,
-        "excel": pd.read_excel,
-    }.get(fmt)
-    if reader:
-        return reader(data_path)
-    raise ValueError(
-        f"Unsupported format: {fmt}. "
-        f"Supported: {', '.join(FILE_FORMAT_CONFIG.keys())}"
-    )
-    except ImportError as e:
-        raise ImportError(
-            f"Format '{fmt}' requires additional dependencies. "
-            f"Install with: pip install pandas[{fmt}] or check documentation. "
-            f"Original error: {e}"
-        ) from e
-
-
-def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
-    return [read_table(fp, file_type) for fp in file_paths]
+    if fmt == "parquet":
+        return pd.read_parquet(data_path)
+    if fmt == "csv":
+        return pd.read_csv(data_path, low_memory=False)
+    raise ValueError(f"Unsupported format: {fmt}.")
 
 
 def iter_file_chunks(
@@ -182,37 +124,17 @@ def iter_file_chunks(
         ValueError: If format doesn't support streaming
     """
     file_type = file_type.lower()
-    if not
+    if file_type not in {"csv", "parquet"}:
         raise ValueError(
-            f"Format '{file_type}' does not support streaming reads. "
-            "Formats with streaming support: csv, parquet"
+            f"Format '{file_type}' does not support streaming reads. Formats with streaming support: csv, parquet"
         )
 
-
-
-
-
-
-
-            yield batch.to_pandas()
-        else:
-            raise ValueError(
-                f"Format '{file_type}' does not support streaming. "
-                f"Use read_table() to load the entire file into memory."
-            )
-    except ImportError as e:
-        raise ImportError(
-            f"Streaming format '{file_type}' requires additional dependencies. "
-            f"Install with: pip install pandas[{file_type}] pyarrow. "
-            f"Original error: {e}"
-        ) from e
-
-
-def default_output_dir(path: str) -> Path:
-    path_obj = Path(path)
-    if path_obj.is_file():
-        return path_obj.parent / f"{path_obj.stem}_preprocessed"
-    return path_obj.with_name(f"{path_obj.name}_preprocessed")
+    if file_type == "csv":
+        yield from pd.read_csv(file_path, chunksize=chunk_size)
+    elif file_type == "parquet":
+        parquet_file = pq.ParquetFile(file_path)
+        for batch in parquet_file.iter_batches(batch_size=chunk_size):
+            yield batch.to_pandas()
 
 
 def read_yaml(path: str | Path):
@@ -232,7 +154,6 @@ def generate_ranking_data(
     embedding_dim: int = 16,
     seed: int = 42,
     custom_sparse_features: Optional[Dict[str, int]] = None,
-    use_simple_names: bool = True,
 ) -> Tuple[pd.DataFrame, List, List, List]:
     """
     Generate synthetic data for ranking tasks (CTR prediction)
@@ -737,7 +658,6 @@ def generate_distributed_ranking_data(
             "category": num_categories,
            "city": num_cities,
        },
-        use_simple_names=False,
     )
 
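The rewritten `iter_file_chunks` keeps only the two streaming-capable readers: pandas' `chunksize` iterator for csv and pyarrow record batches for parquet. A runnable sketch of the parquet path follows; the file and column names are made up for illustration.

```python
import pandas as pd
import pyarrow.parquet as pq

# Write a small parquet file, then stream it back in two batches,
# performing the same RecordBatch -> DataFrame conversion as iter_file_chunks.
df = pd.DataFrame({"user_id": range(10), "label": [0, 1] * 5})
df.to_parquet("sample.parquet", index=False)

for batch in pq.ParquetFile("sample.parquet").iter_batches(batch_size=5):
    chunk = batch.to_pandas()
    print(chunk.shape)  # (5, 2) twice; memory stays bounded by the batch size
```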
nextrec/utils/model.py CHANGED

@@ -2,24 +2,20 @@
 Model-related utilities for NextRec
 
 Date: create on 03/12/2025
-Checkpoint: edit on 31/
+Checkpoint: edit on 31/01/2026
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
-from collections import OrderedDict
-
 import torch
 import torch.nn as nn
 
-from nextrec.loss import (
+from nextrec.loss.listwise import (
     ApproxNDCGLoss,
-    BPRLoss,
-    HingeLoss,
     ListMLELoss,
     ListNetLoss,
     SampledSoftmaxLoss,
-    TripletLoss,
 )
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 
 from nextrec.utils.types import (
     LossName,
@@ -27,13 +23,6 @@ from nextrec.utils.types import (
 )
 
 
-def merge_features(primary, secondary) -> list:
-    merged: OrderedDict[str, object] = OrderedDict()
-    for feat in list(primary or []) + list(secondary or []):
-        merged.setdefault(feat.name, feat)
-    return list(merged.values())
-
-
 def get_mlp_output_dim(params: dict, fallback: int) -> int:
     hidden_dims = params.get("hidden_dims")
     if hidden_dims:
@@ -46,6 +35,8 @@ def select_features(
     names: list[str],
     param_name: str,
 ) -> list:
+    """select features by names from available features."""
+
     if not names:
         return []
 
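Consistent with the now-empty `nextrec/loss/__init__.py` in the RECORD below, loss classes are imported from their concrete modules rather than the package root. The split used in the hunk above, shown on its own:

```python
# Listwise losses now come from nextrec.loss.listwise ...
from nextrec.loss.listwise import (
    ApproxNDCGLoss,
    ListMLELoss,
    ListNetLoss,
    SampledSoftmaxLoss,
)
# ... and pairwise losses from nextrec.loss.pairwise.
from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
```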
{nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nextrec
-Version: 0.5.1
+Version: 0.5.2
 Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
 Project-URL: Homepage, https://github.com/zerolovesea/NextRec
 Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -73,7 +73,7 @@ Description-Content-Type: text/markdown
 
 
 
-
 [](https://deepwiki.com/zerolovesea/NextRec)
 
 Chinese Documentation | [English Version](README_en.md)
@@ -112,7 +112,6 @@ NextRec is a modern PyTorch-based recommendation framework designed for research and engineering
 - **21/12/2025** v0.4.16 added support for [GradNorm](/nextrec/loss/grad_norm.py), configured via `loss_weight='grad_norm'` in compile
 - **12/12/2025** v0.4.9 added the [RQ-VAE](/nextrec/models/representation/rqvae.py) module; the accompanying [dataset](/dataset/ecommerce_task.csv) and [code](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb) are in the repository
 - **07/12/2025** Released the NextRec CLI, which lets users train and predict in one step from a config file; see the [tutorial](/nextrec_cli_preset/NextRec-CLI_zh.md) and [example code](/nextrec_cli_preset)
-- **03/12/2025** NextRec reached 100 🌟! Thanks to everyone for the support
 - **06/12/2025** v0.4.1 added single-machine multi-GPU distributed DDP training, with accompanying [code](tutorials/distributed)
 - **11/11/2025** NextRec v0.1.0 released, with 10+ ranking models, 11 multi-task models, 4 retrieval models, and a unified training/logging/metrics management system
 
@@ -260,11 +259,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 
 Prediction results are always saved to `{checkpoint_path}/predictions/{name}.{save_data_format}`.
 
-> As of the current version 0.5.1
+> As of the current version 0.5.2, the NextRec CLI supports single-machine training; distributed-training support is still under development.
 
 ## Compatible Platforms
 
-The current latest version is 0.5.1
+The current latest version is 0.5.2. All models and test code have been validated on the platforms below; if you run into compatibility issues, please file an issue with the error report and your system version:
 
 | Platform | Configuration |
 |------|------|
{nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/RECORD CHANGED

@@ -1,6 +1,6 @@
 nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
-nextrec/__version__.py,sha256=
-nextrec/cli.py,sha256=
+nextrec/__version__.py,sha256=isJrmDBLRag7Zc2UK9ZovWGOv7ji1Oh-zJtJMNJFkXw,22
+nextrec/cli.py,sha256=ddYlk2zXF7iVFpONhG9Ofwb9dVgXfy112nk_JOHy1kI,29840
 nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/basic/activation.py,sha256=rU-W-DHgiD3AZnMGmD014ChxklfP9BpedDTiwtdgXhA,2762
 nextrec/basic/asserts.py,sha256=eaB4FZJ7Sbh9S8PLJZNiYd7Zi98ca2rDi4S7wSYCaEw,1473
@@ -10,21 +10,21 @@ nextrec/basic/heads.py,sha256=WqvavaH6Y8Au8dLaoUfH2AaOOWgYvjZI5US8avkQNsQ,4009
 nextrec/basic/layers.py,sha256=tawggQMMHlYTGpnubxUAvDPDJe_Lpq-HpLLCSjbJV54,37320
 nextrec/basic/loggers.py,sha256=jJUzUt_kMpjpV2Mqr7qBENWA1olEutTI7dFnpmndUUw,13845
 nextrec/basic/metrics.py,sha256=nVz3AkKwsxj_M97CoZWyQj8_Y9ZM_Icvw_QCM6c33Bc,26262
-nextrec/basic/model.py,sha256=
+nextrec/basic/model.py,sha256=Fp1nRVd-8xq6wgjny02afGdIPUhJ-x1jD9bIPHserjQ,135037
 nextrec/basic/session.py,sha256=mrIsjRJhmvcAfoO1pXX-KB3SK5CCgz89wH8XDoAiGEI,4475
-nextrec/basic/summary.py,sha256=
-nextrec/data/__init__.py,sha256=
-nextrec/data/batch_utils.py,sha256=
-nextrec/data/data_processing.py,sha256=
-nextrec/data/data_utils.py,sha256=
-nextrec/data/dataloader.py,sha256=
-nextrec/data/preprocessor.py,sha256=
-nextrec/loss/__init__.py,sha256=
+nextrec/basic/summary.py,sha256=r5d8VtxJsgY6WRAVRa2-UcGb_sYA9VjMKDfQ31928qM,20494
+nextrec/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nextrec/data/batch_utils.py,sha256=Gbeo6XxP2ZsAAGP6f74MYQth-X_EDnXxkuMIGMr2bPk,3537
+nextrec/data/data_processing.py,sha256=TQdvqBOy8_h2TSjMFibCnWcoqtgyFV4TeLyFC621Xb8,5413
+nextrec/data/data_utils.py,sha256=RBpk2ymWJQHtNaa_yAnRk6bkck2UzazKfPSQsRTR1tU,786
+nextrec/data/dataloader.py,sha256=4fAqHZFSbnKOtXWrXLJ39zYfmTAa-dpySyFON1PlWfk,20503
+nextrec/data/preprocessor.py,sha256=92r7c3B_3BSpCmEqI9gXywTlS319kxf4qqzkxHjBfVk,51350
+nextrec/loss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/loss/grad_norm.py,sha256=I4jAs0f84I7MWmYZOMC0JRUNvBHZzhgpuox0hOtYWDg,7435
 nextrec/loss/listwise.py,sha256=mluxXQt9XiuWGvXA1nk4I0miqaKB6_GPVQqxLhAiJKs,5999
 nextrec/loss/pairwise.py,sha256=9fyH9p2u-N0-jAnNTq3X5Dje0ipj1dob8wp-yQKRra4,3493
 nextrec/loss/pointwise.py,sha256=09nzI1L5eP9raXnj3Q49bD9Clp_JmsSWUvEj7bkTzSw,7474
-nextrec/models/generative/__init__.py,sha256=
+nextrec/models/generative/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/models/generative/tiger.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/models/multi_task/[pre]aitm.py,sha256=A2n0T4JEui-uHgbqITU5lpsmtnP14fQXRZM1peTPvhQ,6661
 nextrec/models/multi_task/[pre]snr_trans.py,sha256=k08tC-TI--a_Tt4_BmX0ZubzntyqwsejutYzbB5F4S4,9077
@@ -65,23 +65,23 @@ nextrec/models/retrieval/mind.py,sha256=I0qVj39ApweRGW3qDNLca5vsNtJwRe7gBLh1peds
 nextrec/models/retrieval/sdm.py,sha256=h9TqVmSJ8YF7hgPci784nAlBg1LazB641c4iEeuiLDg,9956
 nextrec/models/retrieval/youtube_dnn.py,sha256=hLyR4liuusJIjRg4vuaSoSEecYgDICipXnNFiA3o3oY,6351
 nextrec/models/sequential/hstu.py,sha256=XFq-IERFg2ohqg03HkP6YinQaZUXljtYayUmvU-N_IY,18916
-nextrec/models/tree_base/__init__.py,sha256=
-nextrec/models/tree_base/base.py,sha256=
+nextrec/models/tree_base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nextrec/models/tree_base/base.py,sha256=czmIrlbGGMjiwwTcj_nVP5-Bnqrl_VNdh0YGHeRaejk,26101
 nextrec/models/tree_base/catboost.py,sha256=hXINyx7iianwDxOZx3SLm0i-YP1jiC3HcAeqP9A2i4A,3434
 nextrec/models/tree_base/lightgbm.py,sha256=VilMU7SgfHR5LAaaoQo-tY1vkzpSvWovIrgaSeuJ1-A,2263
 nextrec/models/tree_base/xgboost.py,sha256=thOmDIC_nitno_k2mcH2cj2VcS07f9veTG01FMOO-28,1957
-nextrec/utils/__init__.py,sha256=
+nextrec/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nextrec/utils/config.py,sha256=Ngd_u8ZS5br4lIqrBJ_ecLquMF4KJi6TPAGqLZg8H4s,20485
 nextrec/utils/console.py,sha256=RnSUplJnyanSQ6TyMQkP7S1j2rGMver1DbFVqNH6_1k,13581
-nextrec/utils/data.py,sha256=
+nextrec/utils/data.py,sha256=JN30npycFMo_caBQJ5815tciRW3r7YDqSQxnJuXhlVA,22299
 nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
 nextrec/utils/loss.py,sha256=GBWQGpDaYkMJySpdG078XbeUNXUC34PVqFy0AqNS9N0,4578
-nextrec/utils/model.py,sha256=
+nextrec/utils/model.py,sha256=K1-ZfS1umSTVo-lNHKNkBb3xvWnfiJJSTRicSVDpk4s,5148
 nextrec/utils/onnx_utils.py,sha256=KIVV_ELYzj3kCswfsSBZ1F2OnSwRJnXj7sxDBwBoBaA,8668
 nextrec/utils/torch_utils.py,sha256=_a9e6GXa3QKuu0E5RL44QRZ1iJSobbtNcPB3vtaCsu8,12313
 nextrec/utils/types.py,sha256=LFwYCBRo5WeYUh5LSCuyP1Lg9ez0Ih00Es3fUttGAFw,2273
-nextrec-0.5.
-nextrec-0.5.
-nextrec-0.5.
-nextrec-0.5.
-nextrec-0.5.
+nextrec-0.5.2.dist-info/METADATA,sha256=ibFYpr07uouyfvIqXngb9GcezGwZe3gIK-E3lbNtavY,23464
+nextrec-0.5.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+nextrec-0.5.2.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
+nextrec-0.5.2.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
+nextrec-0.5.2.dist-info/RECORD,,
{nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/WHEEL
File without changes

{nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/entry_points.txt
File without changes

{nextrec-0.5.1.dist-info → nextrec-0.5.2.dist-info}/licenses/LICENSE
File without changes