nextrec 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/model.py +38 -16
- nextrec/cli.py +6 -8
- nextrec/utils/config.py +5 -1
- {nextrec-0.5.3.dist-info → nextrec-0.5.5.dist-info}/METADATA +4 -4
- {nextrec-0.5.3.dist-info → nextrec-0.5.5.dist-info}/RECORD +9 -9
- {nextrec-0.5.3.dist-info → nextrec-0.5.5.dist-info}/WHEEL +0 -0
- {nextrec-0.5.3.dist-info → nextrec-0.5.5.dist-info}/entry_points.txt +0 -0
- {nextrec-0.5.3.dist-info → nextrec-0.5.5.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.5.3"
|
|
1
|
+
__version__ = "0.5.5"
|
nextrec/basic/model.py
CHANGED
|
@@ -2099,6 +2099,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
2099
2099
|
]
|
|
2100
2100
|
|
|
2101
2101
|
ctx = mp.get_context("spawn")
|
|
2102
|
+
error_queue = ctx.SimpleQueue()
|
|
2102
2103
|
processes = []
|
|
2103
2104
|
for rank in range(num_processes):
|
|
2104
2105
|
process = ctx.Process(
|
|
@@ -2115,6 +2116,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
2115
2116
|
processor,
|
|
2116
2117
|
rank,
|
|
2117
2118
|
num_processes,
|
|
2119
|
+
error_queue,
|
|
2118
2120
|
),
|
|
2119
2121
|
)
|
|
2120
2122
|
process.start()
|
|
@@ -2127,8 +2129,18 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
2127
2129
|
|
|
2128
2130
|
for process in processes:
|
|
2129
2131
|
if process.exitcode not in (0, None):
|
|
2132
|
+
errors = []
|
|
2133
|
+
try:
|
|
2134
|
+
while not error_queue.empty():
|
|
2135
|
+
errors.append(error_queue.get_nowait())
|
|
2136
|
+
except Exception:
|
|
2137
|
+
pass
|
|
2138
|
+
error_text = (
|
|
2139
|
+
"\n\n".join(errors) if errors else "No worker traceback captured."
|
|
2140
|
+
)
|
|
2130
2141
|
raise RuntimeError(
|
|
2131
|
-
"[BaseModel-predict-streaming Error] One or more inference processes failed
|
|
2142
|
+
"[BaseModel-predict-streaming Error] One or more inference processes failed.\n"
|
|
2143
|
+
+ error_text
|
|
2132
2144
|
)
|
|
2133
2145
|
# Merge part files
|
|
2134
2146
|
existing_parts = [p for p in part_paths if p.exists()]
|
|
@@ -2827,22 +2839,32 @@ def predict_streaming_worker(
|
|
|
2827
2839
|
processor: Any | None,
|
|
2828
2840
|
shard_rank: int,
|
|
2829
2841
|
shard_count: int,
|
|
2842
|
+
error_queue: "mp.SimpleQueue[str] | None" = None,
|
|
2830
2843
|
) -> None:
|
|
2831
|
-
|
|
2832
|
-
|
|
2833
|
-
|
|
2834
|
-
|
|
2835
|
-
|
|
2836
|
-
|
|
2837
|
-
|
|
2838
|
-
|
|
2839
|
-
|
|
2840
|
-
|
|
2841
|
-
|
|
2842
|
-
|
|
2843
|
-
|
|
2844
|
-
|
|
2845
|
-
|
|
2844
|
+
try:
|
|
2845
|
+
model.eval()
|
|
2846
|
+
model.predict_streaming(
|
|
2847
|
+
data=data_path,
|
|
2848
|
+
batch_size=batch_size,
|
|
2849
|
+
save_path=save_path,
|
|
2850
|
+
save_format=save_format,
|
|
2851
|
+
stream_chunk_size=stream_chunk_size,
|
|
2852
|
+
return_dataframe=False,
|
|
2853
|
+
num_workers=num_workers,
|
|
2854
|
+
prefetch_factor=prefetch_factor,
|
|
2855
|
+
processor=processor,
|
|
2856
|
+
num_processes=1,
|
|
2857
|
+
shard_rank=shard_rank,
|
|
2858
|
+
shard_count=shard_count,
|
|
2859
|
+
)
|
|
2860
|
+
except Exception:
|
|
2861
|
+
if error_queue is not None:
|
|
2862
|
+
import traceback
|
|
2863
|
+
|
|
2864
|
+
error_queue.put(
|
|
2865
|
+
f"[PredictWorker Error] rank={shard_rank}\n{traceback.format_exc()}"
|
|
2866
|
+
)
|
|
2867
|
+
raise
|
|
2846
2868
|
|
|
2847
2869
|
|
|
2848
2870
|
class BaseMatchModel(BaseModel):
|
nextrec/cli.py
CHANGED
|
@@ -25,7 +25,7 @@ import resource
|
|
|
25
25
|
import sys
|
|
26
26
|
import time
|
|
27
27
|
from pathlib import Path
|
|
28
|
-
from typing import Any, Dict
|
|
28
|
+
from typing import Any, Dict
|
|
29
29
|
|
|
30
30
|
import pandas as pd
|
|
31
31
|
|
|
@@ -137,7 +137,7 @@ def train_model(train_config_path: str) -> None:
|
|
|
137
137
|
model_cfg = read_yaml(model_cfg_path)
|
|
138
138
|
|
|
139
139
|
# Extract id_column from data config for GAUC metrics
|
|
140
|
-
id_column = data_cfg.get("id_column")
|
|
140
|
+
id_column = data_cfg.get("id_column")
|
|
141
141
|
id_columns = [id_column] if id_column else []
|
|
142
142
|
|
|
143
143
|
log_cli_section("Data")
|
|
@@ -378,6 +378,7 @@ def train_model(train_config_path: str) -> None:
|
|
|
378
378
|
sparse_features,
|
|
379
379
|
sequence_features,
|
|
380
380
|
target,
|
|
381
|
+
id_columns,
|
|
381
382
|
device,
|
|
382
383
|
)
|
|
383
384
|
|
|
@@ -585,19 +586,16 @@ def predict_model(predict_config_path: str) -> None:
|
|
|
585
586
|
sparse_features=sparse_features,
|
|
586
587
|
sequence_features=sequence_features,
|
|
587
588
|
target=target_cols,
|
|
589
|
+
id_columns=id_columns,
|
|
588
590
|
device=predict_cfg.get("device", "cpu"),
|
|
589
591
|
)
|
|
590
|
-
|
|
592
|
+
|
|
591
593
|
model.load_model(
|
|
592
594
|
model_file, map_location=predict_cfg.get("device", "cpu"), verbose=True
|
|
593
595
|
)
|
|
594
596
|
|
|
595
|
-
id_columns = []
|
|
596
|
-
if predict_cfg.get("id_column"):
|
|
597
|
-
id_columns = [predict_cfg["id_column"]]
|
|
598
|
-
model.id_columns = id_columns
|
|
599
|
-
|
|
600
597
|
effective_id_columns = id_columns or model.id_columns
|
|
598
|
+
|
|
601
599
|
log_cli_section("Features")
|
|
602
600
|
log_kv_lines(
|
|
603
601
|
[
|
nextrec/utils/config.py
CHANGED
|
@@ -388,6 +388,7 @@ def build_model_instance(
|
|
|
388
388
|
sparse_features: List["SparseFeature"],
|
|
389
389
|
sequence_features: List["SequenceFeature"],
|
|
390
390
|
target: List[str],
|
|
391
|
+
id_columns: list[str] | str | None,
|
|
391
392
|
device: str,
|
|
392
393
|
) -> Any:
|
|
393
394
|
"""
|
|
@@ -400,6 +401,7 @@ def build_model_instance(
|
|
|
400
401
|
sparse_features: List of sparse feature objects
|
|
401
402
|
sequence_features: List of sequence feature objects
|
|
402
403
|
target: List of target column names
|
|
404
|
+
id_columns: Identifier column name(s) for GAUC or ID passthrough
|
|
403
405
|
device: Device string (e.g., 'cpu', 'cuda:0')
|
|
404
406
|
"""
|
|
405
407
|
dense_map = {f.name: f for f in dense_features}
|
|
@@ -436,7 +438,7 @@ def build_model_instance(
|
|
|
436
438
|
param.kind == inspect.Parameter.VAR_KEYWORD for param in sig_params.values()
|
|
437
439
|
)
|
|
438
440
|
|
|
439
|
-
init_kwargs
|
|
441
|
+
init_kwargs = dict(params_cfg)
|
|
440
442
|
|
|
441
443
|
# Explicit bindings (model_config.feature_bindings) take priority
|
|
442
444
|
for param_name, binding in feature_bindings_cfg.items():
|
|
@@ -517,6 +519,8 @@ def build_model_instance(
|
|
|
517
519
|
|
|
518
520
|
if accepts("target"):
|
|
519
521
|
init_kwargs.setdefault("target", target)
|
|
522
|
+
if accepts("id_columns") or accepts_var_kwargs:
|
|
523
|
+
init_kwargs.setdefault("id_columns", id_columns)
|
|
520
524
|
if accepts("device"):
|
|
521
525
|
init_kwargs.setdefault("device", device)
|
|
522
526
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.5.3
|
|
3
|
+
Version: 0.5.5
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -73,7 +73,7 @@ Description-Content-Type: text/markdown
|
|
|
73
73
|

|
|
74
74
|

|
|
75
75
|

|
|
76
|
-

|
|
77
77
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
78
78
|
|
|
79
79
|
中文文档 | [English Version](README_en.md)
|
|
@@ -259,11 +259,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
259
259
|
|
|
260
260
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
261
261
|
|
|
262
|
-
> 截止当前版本0.5.3,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
262
|
+
> 截止当前版本0.5.5,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
263
263
|
|
|
264
264
|
## 兼容平台
|
|
265
265
|
|
|
266
|
-
当前最新版本为0.5.3,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
266
|
+
当前最新版本为0.5.5,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
267
267
|
|
|
268
268
|
| 平台 | 配置 |
|
|
269
269
|
|------|------|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
|
|
2
|
-
nextrec/__version__.py,sha256=
|
|
3
|
-
nextrec/cli.py,sha256=
|
|
2
|
+
nextrec/__version__.py,sha256=78mfpLewKVki6c9UONSUdlVme_JsN9ZwIfp4Hf4jmG0,22
|
|
3
|
+
nextrec/cli.py,sha256=ywGOMJ-iVLH6bOk2VX_n0r530V11xbFUHVwu6-AlzPw,29675
|
|
4
4
|
nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
5
|
nextrec/basic/activation.py,sha256=rU-W-DHgiD3AZnMGmD014ChxklfP9BpedDTiwtdgXhA,2762
|
|
6
6
|
nextrec/basic/asserts.py,sha256=eaB4FZJ7Sbh9S8PLJZNiYd7Zi98ca2rDi4S7wSYCaEw,1473
|
|
@@ -10,7 +10,7 @@ nextrec/basic/heads.py,sha256=WqvavaH6Y8Au8dLaoUfH2AaOOWgYvjZI5US8avkQNsQ,4009
|
|
|
10
10
|
nextrec/basic/layers.py,sha256=tawggQMMHlYTGpnubxUAvDPDJe_Lpq-HpLLCSjbJV54,37320
|
|
11
11
|
nextrec/basic/loggers.py,sha256=jJUzUt_kMpjpV2Mqr7qBENWA1olEutTI7dFnpmndUUw,13845
|
|
12
12
|
nextrec/basic/metrics.py,sha256=nVz3AkKwsxj_M97CoZWyQj8_Y9ZM_Icvw_QCM6c33Bc,26262
|
|
13
|
-
nextrec/basic/model.py,sha256=
|
|
13
|
+
nextrec/basic/model.py,sha256=J8fkwQF2Y5YUajqucfSpsNuz8sbLqsE3mupCfbkZDhc,135869
|
|
14
14
|
nextrec/basic/session.py,sha256=mrIsjRJhmvcAfoO1pXX-KB3SK5CCgz89wH8XDoAiGEI,4475
|
|
15
15
|
nextrec/basic/summary.py,sha256=r5d8VtxJsgY6WRAVRa2-UcGb_sYA9VjMKDfQ31928qM,20494
|
|
16
16
|
nextrec/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -71,7 +71,7 @@ nextrec/models/tree_base/catboost.py,sha256=hXINyx7iianwDxOZx3SLm0i-YP1jiC3HcAeq
|
|
|
71
71
|
nextrec/models/tree_base/lightgbm.py,sha256=VilMU7SgfHR5LAaaoQo-tY1vkzpSvWovIrgaSeuJ1-A,2263
|
|
72
72
|
nextrec/models/tree_base/xgboost.py,sha256=thOmDIC_nitno_k2mcH2cj2VcS07f9veTG01FMOO-28,1957
|
|
73
73
|
nextrec/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
74
|
-
nextrec/utils/config.py,sha256=
|
|
74
|
+
nextrec/utils/config.py,sha256=uChqdxA3X6gn2G8tlUrobZ1AbaxdSQMt87eocSQVgog,20691
|
|
75
75
|
nextrec/utils/console.py,sha256=RnSUplJnyanSQ6TyMQkP7S1j2rGMver1DbFVqNH6_1k,13581
|
|
76
76
|
nextrec/utils/data.py,sha256=JN30npycFMo_caBQJ5815tciRW3r7YDqSQxnJuXhlVA,22299
|
|
77
77
|
nextrec/utils/embedding.py,sha256=akAEc062MG2cD7VIOllHaqtwzAirQR2gq5iW7oKpGAU,1449
|
|
@@ -80,8 +80,8 @@ nextrec/utils/model.py,sha256=K1-ZfS1umSTVo-lNHKNkBb3xvWnfiJJSTRicSVDpk4s,5148
|
|
|
80
80
|
nextrec/utils/onnx_utils.py,sha256=KIVV_ELYzj3kCswfsSBZ1F2OnSwRJnXj7sxDBwBoBaA,8668
|
|
81
81
|
nextrec/utils/torch_utils.py,sha256=_a9e6GXa3QKuu0E5RL44QRZ1iJSobbtNcPB3vtaCsu8,12313
|
|
82
82
|
nextrec/utils/types.py,sha256=LFwYCBRo5WeYUh5LSCuyP1Lg9ez0Ih00Es3fUttGAFw,2273
|
|
83
|
-
nextrec-0.5.
|
|
84
|
-
nextrec-0.5.
|
|
85
|
-
nextrec-0.5.
|
|
86
|
-
nextrec-0.5.
|
|
87
|
-
nextrec-0.5.
|
|
83
|
+
nextrec-0.5.5.dist-info/METADATA,sha256=XbDqdm6X1jgRyk_GSMXyMVKtZe1AsuCNTNjcC4NTD-A,23464
|
|
84
|
+
nextrec-0.5.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
85
|
+
nextrec-0.5.5.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
|
|
86
|
+
nextrec-0.5.5.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
|
|
87
|
+
nextrec-0.5.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|