nextrec 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nextrec/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.5.2"
1
+ __version__ = "0.5.4"
nextrec/basic/model.py CHANGED
@@ -1707,7 +1707,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
1707
1707
  if (
1708
1708
  save_path is not None
1709
1709
  and not return_dataframe
1710
- and isinstance(data, (str, os.PathLike))
1710
+ and isinstance(data, (str, os.PathLike, DataLoader))
1711
1711
  ):
1712
1712
  if num_processes > 1 and not isinstance(data, (str, os.PathLike)):
1713
1713
  raise ValueError(
@@ -2099,6 +2099,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
2099
2099
  ]
2100
2100
 
2101
2101
  ctx = mp.get_context("spawn")
2102
+ error_queue = ctx.SimpleQueue()
2102
2103
  processes = []
2103
2104
  for rank in range(num_processes):
2104
2105
  process = ctx.Process(
@@ -2115,6 +2116,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
2115
2116
  processor,
2116
2117
  rank,
2117
2118
  num_processes,
2119
+ error_queue,
2118
2120
  ),
2119
2121
  )
2120
2122
  process.start()
@@ -2127,8 +2129,18 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
2127
2129
 
2128
2130
  for process in processes:
2129
2131
  if process.exitcode not in (0, None):
2132
+ errors = []
2133
+ try:
2134
+ while not error_queue.empty():
2135
+ errors.append(error_queue.get_nowait())
2136
+ except Exception:
2137
+ pass
2138
+ error_text = (
2139
+ "\n\n".join(errors) if errors else "No worker traceback captured."
2140
+ )
2130
2141
  raise RuntimeError(
2131
- "[BaseModel-predict-streaming Error] One or more inference processes failed."
2142
+ "[BaseModel-predict-streaming Error] One or more inference processes failed.\n"
2143
+ + error_text
2132
2144
  )
2133
2145
  # Merge part files
2134
2146
  existing_parts = [p for p in part_paths if p.exists()]
@@ -2827,22 +2839,32 @@ def predict_streaming_worker(
2827
2839
  processor: Any | None,
2828
2840
  shard_rank: int,
2829
2841
  shard_count: int,
2842
+ error_queue: "mp.SimpleQueue[str] | None" = None,
2830
2843
  ) -> None:
2831
- model.eval()
2832
- model.predict_streaming(
2833
- data=data_path,
2834
- batch_size=batch_size,
2835
- save_path=save_path,
2836
- save_format=save_format,
2837
- stream_chunk_size=stream_chunk_size,
2838
- return_dataframe=False,
2839
- num_workers=num_workers,
2840
- prefetch_factor=prefetch_factor,
2841
- processor=processor,
2842
- num_processes=1,
2843
- shard_rank=shard_rank,
2844
- shard_count=shard_count,
2845
- )
2844
+ try:
2845
+ model.eval()
2846
+ model.predict_streaming(
2847
+ data=data_path,
2848
+ batch_size=batch_size,
2849
+ save_path=save_path,
2850
+ save_format=save_format,
2851
+ stream_chunk_size=stream_chunk_size,
2852
+ return_dataframe=False,
2853
+ num_workers=num_workers,
2854
+ prefetch_factor=prefetch_factor,
2855
+ processor=processor,
2856
+ num_processes=1,
2857
+ shard_rank=shard_rank,
2858
+ shard_count=shard_count,
2859
+ )
2860
+ except Exception:
2861
+ if error_queue is not None:
2862
+ import traceback
2863
+
2864
+ error_queue.put(
2865
+ f"[PredictWorker Error] rank={shard_rank}\n{traceback.format_exc()}"
2866
+ )
2867
+ raise
2846
2868
 
2847
2869
 
2848
2870
  class BaseMatchModel(BaseModel):
nextrec/cli.py CHANGED
@@ -25,7 +25,7 @@ import resource
25
25
  import sys
26
26
  import time
27
27
  from pathlib import Path
28
- from typing import Any, Dict, List
28
+ from typing import Any, Dict
29
29
 
30
30
  import pandas as pd
31
31
 
@@ -592,12 +592,7 @@ def predict_model(predict_config_path: str) -> None:
592
592
  model_file, map_location=predict_cfg.get("device", "cpu"), verbose=True
593
593
  )
594
594
 
595
- id_columns = []
596
- if predict_cfg.get("id_column"):
597
- id_columns = [predict_cfg["id_column"]]
598
- model.id_columns = id_columns
599
-
600
- effective_id_columns = id_columns or model.id_columns
595
+ effective_id_columns = model.id_columns
601
596
  log_cli_section("Features")
602
597
  log_kv_lines(
603
598
  [
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.5.2
3
+ Version: 0.5.4
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -73,7 +73,7 @@ Description-Content-Type: text/markdown
73
73
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
74
74
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
75
75
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
76
- ![Version](https://img.shields.io/badge/Version-0.5.2-orange.svg)
76
+ ![Version](https://img.shields.io/badge/Version-0.5.4-orange.svg)
77
77
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
78
78
 
79
79
  中文文档 | [English Version](README_en.md)
@@ -259,11 +259,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
259
259
 
260
260
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
261
261
 
262
- > 截止当前版本0.5.2,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
262
+ > 截止当前版本0.5.4,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
263
263
 
264
264
  ## 兼容平台
265
265
 
266
- 当前最新版本为0.5.2,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
266
+ 当前最新版本为0.5.4,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
267
267
 
268
268
  | 平台 | 配置 |
269
269
  |------|------|
@@ -1,6 +1,6 @@
1
1
  nextrec/__init__.py,sha256=_M3oUqyuvQ5k8Th_3wId6hQ_caclh7M5ad51XN09m98,235
2
- nextrec/__version__.py,sha256=isJrmDBLRag7Zc2UK9ZovWGOv7ji1Oh-zJtJMNJFkXw,22
3
- nextrec/cli.py,sha256=ddYlk2zXF7iVFpONhG9Ofwb9dVgXfy112nk_JOHy1kI,29840
2
+ nextrec/__version__.py,sha256=DITpct-LrdIsTgwx2NgH5Ghx5y8Xgz1YMimy1ZV5RTY,22
3
+ nextrec/cli.py,sha256=IxYzZavCc5ACe-8i2EJbkdY66do4x7siGHSxoS1RoL0,29676
4
4
  nextrec/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  nextrec/basic/activation.py,sha256=rU-W-DHgiD3AZnMGmD014ChxklfP9BpedDTiwtdgXhA,2762
6
6
  nextrec/basic/asserts.py,sha256=eaB4FZJ7Sbh9S8PLJZNiYd7Zi98ca2rDi4S7wSYCaEw,1473
@@ -10,7 +10,7 @@ nextrec/basic/heads.py,sha256=WqvavaH6Y8Au8dLaoUfH2AaOOWgYvjZI5US8avkQNsQ,4009
10
10
  nextrec/basic/layers.py,sha256=tawggQMMHlYTGpnubxUAvDPDJe_Lpq-HpLLCSjbJV54,37320
11
11
  nextrec/basic/loggers.py,sha256=jJUzUt_kMpjpV2Mqr7qBENWA1olEutTI7dFnpmndUUw,13845
12
12
  nextrec/basic/metrics.py,sha256=nVz3AkKwsxj_M97CoZWyQj8_Y9ZM_Icvw_QCM6c33Bc,26262
13
- nextrec/basic/model.py,sha256=Fp1nRVd-8xq6wgjny02afGdIPUhJ-x1jD9bIPHserjQ,135037
13
+ nextrec/basic/model.py,sha256=J8fkwQF2Y5YUajqucfSpsNuz8sbLqsE3mupCfbkZDhc,135869
14
14
  nextrec/basic/session.py,sha256=mrIsjRJhmvcAfoO1pXX-KB3SK5CCgz89wH8XDoAiGEI,4475
15
15
  nextrec/basic/summary.py,sha256=r5d8VtxJsgY6WRAVRa2-UcGb_sYA9VjMKDfQ31928qM,20494
16
16
  nextrec/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -80,8 +80,8 @@ nextrec/utils/model.py,sha256=K1-ZfS1umSTVo-lNHKNkBb3xvWnfiJJSTRicSVDpk4s,5148
80
80
  nextrec/utils/onnx_utils.py,sha256=KIVV_ELYzj3kCswfsSBZ1F2OnSwRJnXj7sxDBwBoBaA,8668
81
81
  nextrec/utils/torch_utils.py,sha256=_a9e6GXa3QKuu0E5RL44QRZ1iJSobbtNcPB3vtaCsu8,12313
82
82
  nextrec/utils/types.py,sha256=LFwYCBRo5WeYUh5LSCuyP1Lg9ez0Ih00Es3fUttGAFw,2273
83
- nextrec-0.5.2.dist-info/METADATA,sha256=ibFYpr07uouyfvIqXngb9GcezGwZe3gIK-E3lbNtavY,23464
84
- nextrec-0.5.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
85
- nextrec-0.5.2.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
86
- nextrec-0.5.2.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
87
- nextrec-0.5.2.dist-info/RECORD,,
83
+ nextrec-0.5.4.dist-info/METADATA,sha256=OlEXZ7qVUN_3IsAFoDxmERQYxGro70BDo3v3dy_8zec,23464
84
+ nextrec-0.5.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
85
+ nextrec-0.5.4.dist-info/entry_points.txt,sha256=NN-dNSdfMRTv86bNXM7d3ZEPW2BQC6bRi7QP7i9cIps,45
86
+ nextrec-0.5.4.dist-info/licenses/LICENSE,sha256=COP1BsqnEUwdx6GCkMjxOo5v3pUe4-Go_CdmQmSfYXM,1064
87
+ nextrec-0.5.4.dist-info/RECORD,,