nextrec 0.4.9__tar.gz → 0.4.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. {nextrec-0.4.9 → nextrec-0.4.10}/PKG-INFO +5 -6
  2. {nextrec-0.4.9 → nextrec-0.4.10}/README.md +4 -5
  3. {nextrec-0.4.9 → nextrec-0.4.10}/README_en.md +3 -4
  4. {nextrec-0.4.9 → nextrec-0.4.10}/docs/en/Getting started guide.md +1 -1
  5. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/conf.py +1 -1
  6. {nextrec-0.4.9 → nextrec-0.4.10}/docs/zh/快速上手.md +1 -1
  7. nextrec-0.4.10/nextrec/__version__.py +1 -0
  8. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/model.py +4 -3
  9. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/cli.py +181 -34
  10. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/data/dataloader.py +19 -20
  11. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/deepfm.py +4 -5
  12. nextrec-0.4.10/nextrec/models/ranking/eulernet.py +365 -0
  13. nextrec-0.4.10/nextrec/models/ranking/lr.py +120 -0
  14. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/NextRec-CLI.md +97 -64
  15. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/NextRec-CLI_zh.md +92 -59
  16. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/feature_config.yaml +2 -2
  17. nextrec-0.4.10/nextrec_cli_preset/predict_config.yaml +32 -0
  18. nextrec-0.4.10/nextrec_cli_preset/predict_config_template.yaml +64 -0
  19. nextrec-0.4.10/nextrec_cli_preset/train_config.yaml +37 -0
  20. nextrec-0.4.10/nextrec_cli_preset/train_config_template.yaml +149 -0
  21. {nextrec-0.4.9 → nextrec-0.4.10}/pyproject.toml +1 -1
  22. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_ranking_models.py +90 -0
  23. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/notebooks/en/Hands on nextrec.ipynb +2 -2
  24. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/notebooks/zh/快速入门nextrec.ipynb +2 -2
  25. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/run_all_ranking_models.py +4 -0
  26. nextrec-0.4.9/nextrec/__version__.py +0 -1
  27. nextrec-0.4.9/nextrec/models/ranking/eulernet.py +0 -0
  28. nextrec-0.4.9/nextrec/models/ranking/lr.py +0 -0
  29. nextrec-0.4.9/nextrec_cli_preset/predict_config.yaml +0 -24
  30. nextrec-0.4.9/nextrec_cli_preset/train_config.yaml +0 -45
  31. {nextrec-0.4.9 → nextrec-0.4.10}/.github/workflows/publish.yml +0 -0
  32. {nextrec-0.4.9 → nextrec-0.4.10}/.github/workflows/tests.yml +0 -0
  33. {nextrec-0.4.9 → nextrec-0.4.10}/.gitignore +0 -0
  34. {nextrec-0.4.9 → nextrec-0.4.10}/.readthedocs.yaml +0 -0
  35. {nextrec-0.4.9 → nextrec-0.4.10}/CODE_OF_CONDUCT.md +0 -0
  36. {nextrec-0.4.9 → nextrec-0.4.10}/CONTRIBUTING.md +0 -0
  37. {nextrec-0.4.9 → nextrec-0.4.10}/LICENSE +0 -0
  38. {nextrec-0.4.9 → nextrec-0.4.10}/MANIFEST.in +0 -0
  39. {nextrec-0.4.9 → nextrec-0.4.10}/assets/Feature Configuration.png +0 -0
  40. {nextrec-0.4.9 → nextrec-0.4.10}/assets/Model Parameters.png +0 -0
  41. {nextrec-0.4.9 → nextrec-0.4.10}/assets/Training Configuration.png +0 -0
  42. {nextrec-0.4.9 → nextrec-0.4.10}/assets/Training logs.png +0 -0
  43. {nextrec-0.4.9 → nextrec-0.4.10}/assets/logo.png +0 -0
  44. {nextrec-0.4.9 → nextrec-0.4.10}/assets/mmoe_tutorial.png +0 -0
  45. {nextrec-0.4.9 → nextrec-0.4.10}/assets/nextrec_diagram.png +0 -0
  46. {nextrec-0.4.9 → nextrec-0.4.10}/assets/test data.png +0 -0
  47. {nextrec-0.4.9 → nextrec-0.4.10}/dataset/ctcvr_task.csv +0 -0
  48. {nextrec-0.4.9 → nextrec-0.4.10}/dataset/ecommerce_task.csv +0 -0
  49. {nextrec-0.4.9 → nextrec-0.4.10}/dataset/match_task.csv +0 -0
  50. {nextrec-0.4.9 → nextrec-0.4.10}/dataset/movielens_100k.csv +0 -0
  51. {nextrec-0.4.9 → nextrec-0.4.10}/dataset/multitask_task.csv +0 -0
  52. {nextrec-0.4.9 → nextrec-0.4.10}/dataset/ranking_task.csv +0 -0
  53. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/Makefile +0 -0
  54. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/index.md +0 -0
  55. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/make.bat +0 -0
  56. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/modules.rst +0 -0
  57. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/nextrec.basic.rst +0 -0
  58. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/nextrec.data.rst +0 -0
  59. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/nextrec.loss.rst +0 -0
  60. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/nextrec.rst +0 -0
  61. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/nextrec.utils.rst +0 -0
  62. {nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/requirements.txt +0 -0
  63. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/__init__.py +0 -0
  64. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/__init__.py +0 -0
  65. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/activation.py +0 -0
  66. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/callback.py +0 -0
  67. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/features.py +0 -0
  68. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/layers.py +0 -0
  69. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/loggers.py +0 -0
  70. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/metrics.py +0 -0
  71. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/session.py +0 -0
  72. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/data/__init__.py +0 -0
  73. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/data/batch_utils.py +0 -0
  74. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/data/data_processing.py +0 -0
  75. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/data/data_utils.py +0 -0
  76. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/data/preprocessor.py +0 -0
  77. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/loss/__init__.py +0 -0
  78. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/loss/listwise.py +0 -0
  79. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/loss/loss_utils.py +0 -0
  80. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/loss/pairwise.py +0 -0
  81. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/loss/pointwise.py +0 -0
  82. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/generative/__init__.py +0 -0
  83. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/generative/hstu.py +0 -0
  84. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/generative/tiger.py +0 -0
  85. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/multi_task/__init__.py +0 -0
  86. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/multi_task/esmm.py +0 -0
  87. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/multi_task/mmoe.py +0 -0
  88. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/multi_task/ple.py +0 -0
  89. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/multi_task/poso.py +0 -0
  90. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/multi_task/share_bottom.py +0 -0
  91. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/__init__.py +0 -0
  92. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/afm.py +0 -0
  93. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/autoint.py +0 -0
  94. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/dcn.py +0 -0
  95. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/dcn_v2.py +0 -0
  96. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/dien.py +0 -0
  97. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/din.py +0 -0
  98. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/ffm.py +0 -0
  99. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/fibinet.py +0 -0
  100. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/fm.py +0 -0
  101. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/masknet.py +0 -0
  102. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/pnn.py +0 -0
  103. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/widedeep.py +0 -0
  104. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/xdeepfm.py +0 -0
  105. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/representation/__init__.py +0 -0
  106. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/representation/rqvae.py +0 -0
  107. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/retrieval/__init__.py +0 -0
  108. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/retrieval/dssm.py +0 -0
  109. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/retrieval/dssm_v2.py +0 -0
  110. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/retrieval/mind.py +0 -0
  111. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/retrieval/sdm.py +0 -0
  112. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/retrieval/youtube_dnn.py +0 -0
  113. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/__init__.py +0 -0
  114. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/config.py +0 -0
  115. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/console.py +0 -0
  116. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/data.py +0 -0
  117. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/embedding.py +0 -0
  118. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/feature.py +0 -0
  119. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/model.py +0 -0
  120. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec/utils/torch_utils.py +0 -0
  121. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  122. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  123. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  124. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  125. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  126. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  127. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  128. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  129. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  130. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  131. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  132. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  133. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  134. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  135. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  136. {nextrec-0.4.9 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  137. {nextrec-0.4.9 → nextrec-0.4.10}/pytest.ini +0 -0
  138. {nextrec-0.4.9 → nextrec-0.4.10}/requirements.txt +0 -0
  139. {nextrec-0.4.9 → nextrec-0.4.10}/scripts/format_code.py +0 -0
  140. {nextrec-0.4.9 → nextrec-0.4.10}/test/__init__.py +0 -0
  141. {nextrec-0.4.9 → nextrec-0.4.10}/test/conftest.py +0 -0
  142. {nextrec-0.4.9 → nextrec-0.4.10}/test/helpers.py +0 -0
  143. {nextrec-0.4.9 → nextrec-0.4.10}/test/run_tests.py +0 -0
  144. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_base_model_regularization.py +0 -0
  145. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_generative_models.py +0 -0
  146. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_layers.py +0 -0
  147. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_losses.py +0 -0
  148. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_match_models.py +0 -0
  149. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_multitask_models.py +0 -0
  150. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_preprocessor.py +0 -0
  151. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_utils_console.py +0 -0
  152. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_utils_data.py +0 -0
  153. {nextrec-0.4.9 → nextrec-0.4.10}/test/test_utils_embedding.py +0 -0
  154. {nextrec-0.4.9 → nextrec-0.4.10}/test_requirements.txt +0 -0
  155. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/distributed/example_distributed_training.py +0 -0
  156. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  157. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/example_match_dssm.py +0 -0
  158. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/example_multitask.py +0 -0
  159. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/example_ranking_din.py +0 -0
  160. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/movielen_match_dssm.py +0 -0
  161. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/movielen_ranking_deepfm.py +0 -0
  162. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  163. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  164. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  165. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  166. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/run_all_match_models.py +0 -0
  167. {nextrec-0.4.9 → nextrec-0.4.10}/tutorials/run_all_multitask_models.py +0 -0
{nextrec-0.4.9 → nextrec-0.4.10}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.4.9
+ Version: 0.4.10
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -66,7 +66,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.9-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.10-orange.svg)

  中文文档 | [English Version](README_en.md)

@@ -99,11 +99,10 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程

  ## NextRec近期进展

- - **12/12/2025** 在v0.4.9中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
+ - **12/12/2025** 在v0.4.10中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
  - **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
  - **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
  - **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
- - **23/11/2025** 在v0.2.2中对basemodel进行了逻辑上的大幅重构和流程统一,并且对listwise/pairwise/pointwise损失进行了统一
  - **11/11/2025** NextRec v0.1.0发布,我们提供了10余种Ranking模型,4种多任务模型和4种召回模型,以及统一的训练/日志/指标管理系统

  ## 架构
@@ -241,11 +240,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
  ```

- > 截止当前版本0.4.9,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.4.10,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.9,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.4.10,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|

{nextrec-0.4.9 → nextrec-0.4.10}/README.md
@@ -7,7 +7,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.9-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.10-orange.svg)

  中文文档 | [English Version](README_en.md)

@@ -40,11 +40,10 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程

  ## NextRec近期进展

- - **12/12/2025** 在v0.4.9中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
+ - **12/12/2025** 在v0.4.10中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
  - **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
  - **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
  - **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
- - **23/11/2025** 在v0.2.2中对basemodel进行了逻辑上的大幅重构和流程统一,并且对listwise/pairwise/pointwise损失进行了统一
  - **11/11/2025** NextRec v0.1.0发布,我们提供了10余种Ranking模型,4种多任务模型和4种召回模型,以及统一的训练/日志/指标管理系统

  ## 架构
@@ -182,11 +181,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
  ```

- > 截止当前版本0.4.9,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.4.10,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.9,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.4.10,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|

{nextrec-0.4.9 → nextrec-0.4.10}/README_en.md
@@ -7,7 +7,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.9-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.10-orange.svg)

  English | [中文文档](README.md)

@@ -46,7 +46,6 @@ NextRec is a modern recommendation framework built on PyTorch, delivering a unif
  - **07/12/2025** Released the NextRec CLI tool to run training/inference from configs. See the [guide](/nextrec_cli_preset/NextRec-CLI.md) and [reference code](/nextrec_cli_preset).
  - **03/12/2025** NextRec reached 100 ⭐—thanks for the support!
  - **06/12/2025** Added single-machine multi-GPU DDP training in v0.4.1 with supporting [code](tutorials/distributed).
- - **23/11/2025** Major logical refactor of basemodel and unification of listwise/pairwise/pointwise losses in v0.2.2.
  - **11/11/2025** NextRec v0.1.0 released with 10+ ranking models, 4 multi-task models, 4 retrieval models, and a unified training/logging/metrics system.

  ## Architecture
@@ -186,11 +185,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
  ```

- > As of version 0.4.9, NextRec CLI supports single-machine training; distributed training features are currently under development.
+ > As of version 0.4.10, NextRec CLI supports single-machine training; distributed training features are currently under development.

  ## Platform Compatibility

- The current version is 0.4.9. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
+ The current version is 0.4.10. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:

  | Platform | Configuration |
  |----------|---------------|

{nextrec-0.4.9 → nextrec-0.4.10}/docs/en/Getting started guide.md
@@ -102,4 +102,4 @@ metrics = model.evaluate(
  - Multi-task: `tutorials/example_multitask.py`
  - Notebooks: `tutorials/notebooks/zh/Hands on nextrec.ipynb`, `tutorials/notebooks/zh/Hands on dataprocessor.ipynb`

- For large offline features or streaming loads, use `DataProcessor` and `RecDataLoader` to configure CSV/Parquet paths and streaming (`load_full=False`) without changing model code.
+ For large offline features or streaming loads, use `DataProcessor` and `RecDataLoader` to configure CSV/Parquet paths and streaming (`streaming=True`) without changing model code.
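
The doc line above reflects the flag rename in this release: `load_full=False` becomes `streaming=True`. A minimal sketch of the call site, based only on the `RecDataLoader` constructor and `create_dataloader` arguments visible in the cli.py and dataloader.py hunks below; the feature lists and the fitted `DataProcessor` are assumed to exist and are not defined here.

```python
# Sketch only: dense_features, sparse_features, sequence_features and
# processor are assumed to be built elsewhere (e.g. the way the CLI does it).
from nextrec.data.dataloader import RecDataLoader

rec_loader = RecDataLoader(
    dense_features=dense_features,        # assumed: list of DenseFeature
    sparse_features=sparse_features,      # assumed: list of SparseFeature
    sequence_features=sequence_features,  # assumed: list of SequenceFeature
    target=None,
    processor=processor,                  # assumed: fitted DataProcessor
)

# streaming=True replaces the old load_full=False: CSV/Parquet files are read
# chunk by chunk instead of being concatenated in memory.
loader = rec_loader.create_dataloader(
    data="path/to/features.parquet",      # hypothetical path
    batch_size=512,
    shuffle=False,                        # shuffle is ignored in streaming mode
    streaming=True,
    chunk_size=20000,
)
```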

{nextrec-0.4.9 → nextrec-0.4.10}/docs/rtd/conf.py
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
  project = "NextRec"
  copyright = "2025, Yang Zhou"
  author = "Yang Zhou"
- release = "0.4.9"
+ release = "0.4.10"

  extensions = [
  "myst_parser",

{nextrec-0.4.9 → nextrec-0.4.10}/docs/zh/快速上手.md
@@ -102,4 +102,4 @@ metrics = model.evaluate(
  - 多任务:`tutorials/example_multitask.py`
  - Notebook:`tutorials/notebooks/zh/Hands on nextrec.ipynb`、`tutorials/notebooks/zh/Hands on dataprocessor.ipynb`

- 如果需要大规模离线特征或流式加载,可结合 `DataProcessor`、`RecDataLoader` 配置 CSV/Parquet 路径与流式参数(`load_full=False`),在不修改模型代码的情况下完成训练与推理。
+ 如果需要大规模离线特征或流式加载,可结合 `DataProcessor`、`RecDataLoader` 配置 CSV/Parquet 路径与流式参数(`streaming=True`),在不修改模型代码的情况下完成训练与推理。

nextrec-0.4.10/nextrec/__version__.py
@@ -0,0 +1 @@
+ __version__ = "0.4.10"

{nextrec-0.4.9 → nextrec-0.4.10}/nextrec/basic/model.py
@@ -1376,7 +1376,7 @@ class BaseModel(FeatureSet, nn.Module):
  data=data,
  batch_size=batch_size,
  shuffle=False,
- load_full=False,
+ streaming=True,
  chunk_size=streaming_chunk_size,
  )
  else:
@@ -1510,7 +1510,7 @@ class BaseModel(FeatureSet, nn.Module):
  data=data,
  batch_size=batch_size,
  shuffle=False,
- load_full=False,
+ streaming=True,
  chunk_size=streaming_chunk_size,
  )
  elif not isinstance(data, DataLoader):
@@ -1605,7 +1605,8 @@ class BaseModel(FeatureSet, nn.Module):
  if collected_frames
  else pd.DataFrame(columns=pred_columns or [])
  )
- return pd.DataFrame(columns=pred_columns or [])
+ # Return the actual save path when not returning dataframe
+ return target_path

  def save_model(
  self,
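
With the hunk above, `BaseModel.predict(..., return_dataframe=False, save_path=...)` now returns the path the predictions were written to instead of an empty DataFrame; the CLI hunks further down rely on exactly this. A hedged sketch of handling both return types, with the call arguments taken from the cli.py usage in this diff and `model`/`pred_loader` assumed to exist:

```python
from pathlib import Path
import pandas as pd

# model and pred_loader are assumed to be a trained BaseModel subclass and a
# prediction DataLoader built as in the CLI section below.
result = model.predict(
    data=pred_loader,
    batch_size=512,
    return_dataframe=False,   # write results to disk rather than collecting frames
    save_path="pred.csv",     # hypothetical file name
    save_format="csv",
)

if isinstance(result, Path):
    print(f"Predictions saved to {result}")      # 0.4.10 behaviour
elif isinstance(result, pd.DataFrame):
    print(result.head())                          # return_dataframe=True path
```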

{nextrec-0.4.9 → nextrec-0.4.10}/nextrec/cli.py
@@ -29,7 +29,7 @@ from typing import Any, Dict, List
  import pandas as pd

  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
- from nextrec.basic.loggers import setup_logger
+ from nextrec.basic.loggers import colorize, format_kv, setup_logger
  from nextrec.data.data_utils import split_dict_random
  from nextrec.data.dataloader import RecDataLoader
  from nextrec.data.preprocessor import DataProcessor
@@ -52,6 +52,17 @@ from nextrec.utils.feature import normalize_to_list
  logger = logging.getLogger(__name__)


+ def log_cli_section(title: str) -> None:
+ logger.info("")
+ logger.info(colorize(f"[{title}]", color="bright_blue", bold=True))
+ logger.info(colorize("-" * 80, color="bright_blue"))
+
+
+ def log_kv_lines(items: list[tuple[str, Any]]) -> None:
+ for label, value in items:
+ logger.info(format_kv(label, value))
+
+
  def train_model(train_config_path: str) -> None:
  """
  Train a NextRec model using the provided configuration file.
@@ -74,8 +85,17 @@ def train_model(train_config_path: str) -> None:
  artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
  session_dir = artifact_root / session_id
  setup_logger(session_id=session_id)
- logger.info(
- f"[NextRec CLI] Training start | version={get_nextrec_version()} | session_id={session_id} | artifacts={session_dir.resolve()}"
+
+ log_cli_section("CLI")
+ log_kv_lines(
+ [
+ ("Mode", "train"),
+ ("Version", get_nextrec_version()),
+ ("Session ID", session_id),
+ ("Artifacts", session_dir.resolve()),
+ ("Config", config_file.resolve()),
+ ("Command", " ".join(sys.argv)),
+ ]
  )

  processor_path = session_dir / "processor.pkl"
@@ -102,11 +122,53 @@ def train_model(train_config_path: str) -> None:
  cfg.get("model_config", "model_config.yaml"), config_dir
  )

+ log_cli_section("Config")
+ log_kv_lines(
+ [
+ ("Train config", config_file.resolve()),
+ ("Feature config", feature_cfg_path),
+ ("Model config", model_cfg_path),
+ ]
+ )
+
  feature_cfg = read_yaml(feature_cfg_path)
  model_cfg = read_yaml(model_cfg_path)

+ # Extract id_column from data config for GAUC metrics
+ id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
+ id_columns = [id_column] if id_column else []
+
+ log_cli_section("Data")
+ log_kv_lines(
+ [
+ ("Data path", data_path),
+ ("Format", data_cfg.get("format", "auto")),
+ ("Streaming", streaming),
+ ("Target", target),
+ ("ID column", id_column or "(not set)"),
+ ]
+ )
+ if data_cfg.get("valid_ratio") is not None:
+ logger.info(format_kv("Valid ratio", data_cfg.get("valid_ratio")))
+ if data_cfg.get("val_path") or data_cfg.get("valid_path"):
+ logger.info(
+ format_kv(
+ "Validation path",
+ resolve_path(
+ data_cfg.get("val_path") or data_cfg.get("valid_path"), config_dir
+ ),
+ )
+ )
+
  if streaming:
  file_paths, file_type = resolve_file_paths(str(data_path))
+ log_kv_lines(
+ [
+ ("File type", file_type),
+ ("Files", len(file_paths)),
+ ("Chunk size", dataloader_chunk_size),
+ ]
+ )
  first_file = file_paths[0]
  first_chunk_size = max(1, min(dataloader_chunk_size, 1000))
  chunk_iter = iter_file_chunks(first_file, file_type, first_chunk_size)
@@ -118,14 +180,12 @@ def train_model(train_config_path: str) -> None:

  else:
  df = read_table(data_path, data_cfg.get("format"))
+ logger.info(format_kv("Rows", len(df)))
+ logger.info(format_kv("Columns", len(df.columns)))
  df_columns = list(df.columns)

  dense_names, sparse_names, sequence_names = select_features(feature_cfg, df_columns)

- # Extract id_column from data config for GAUC metrics
- id_column = data_cfg.get("id_column") or data_cfg.get("user_id_column")
- id_columns = [id_column] if id_column else []
-
  used_columns = dense_names + sparse_names + sequence_names + target + id_columns

  # keep order but drop duplicates
@@ -141,6 +201,17 @@ def train_model(train_config_path: str) -> None:
  processor, feature_cfg, dense_names, sparse_names, sequence_names
  )

+ log_cli_section("Features")
+ log_kv_lines(
+ [
+ ("Dense features", len(dense_names)),
+ ("Sparse features", len(sparse_names)),
+ ("Sequence features", len(sequence_names)),
+ ("Targets", len(target)),
+ ("Used columns", len(unique_used_columns)),
+ ]
+ )
+
  if streaming:
  processor.fit(str(data_path), chunk_size=dataloader_chunk_size)
  processed = None
@@ -244,7 +315,7 @@ def train_model(train_config_path: str) -> None:
  data=train_stream_source,
  batch_size=dataloader_cfg.get("train_batch_size", 512),
  shuffle=dataloader_cfg.get("train_shuffle", True),
- load_full=False,
+ streaming=True,
  chunk_size=dataloader_chunk_size,
  num_workers=dataloader_cfg.get("num_workers", 0),
  )
@@ -255,7 +326,7 @@ def train_model(train_config_path: str) -> None:
  data=str(val_data_resolved),
  batch_size=dataloader_cfg.get("valid_batch_size", 512),
  shuffle=dataloader_cfg.get("valid_shuffle", False),
- load_full=False,
+ streaming=True,
  chunk_size=dataloader_chunk_size,
  num_workers=dataloader_cfg.get("num_workers", 0),
  )
@@ -264,7 +335,7 @@ def train_model(train_config_path: str) -> None:
  data=streaming_valid_files,
  batch_size=dataloader_cfg.get("valid_batch_size", 512),
  shuffle=dataloader_cfg.get("valid_shuffle", False),
- load_full=False,
+ streaming=True,
  chunk_size=dataloader_chunk_size,
  num_workers=dataloader_cfg.get("num_workers", 0),
  )
@@ -295,6 +366,15 @@ def train_model(train_config_path: str) -> None:
  device,
  )

+ log_cli_section("Model")
+ log_kv_lines(
+ [
+ ("Model", model.__class__.__name__),
+ ("Device", device),
+ ("Session ID", session_id),
+ ]
+ )
+
  model.compile(
  optimizer=train_cfg.get("optimizer", "adam"),
  optimizer_params=train_cfg.get("optimizer_params", {}),
@@ -325,13 +405,30 @@ def predict_model(predict_config_path: str) -> None:
  config_dir = config_file.resolve().parent
  cfg = read_yaml(config_file)

- session_cfg = cfg.get("session", {}) or {}
- session_id = session_cfg.get("id", "masknet_tutorial")
- artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
- session_dir = Path(cfg.get("checkpoint_path") or (artifact_root / session_id))
+ # Checkpoint path is the primary configuration
+ if "checkpoint_path" not in cfg:
+ session_cfg = cfg.get("session", {}) or {}
+ session_id = session_cfg.get("id", "nextrec_session")
+ artifact_root = Path(session_cfg.get("artifact_root", "nextrec_logs"))
+ session_dir = artifact_root / session_id
+ else:
+ session_dir = Path(cfg["checkpoint_path"])
+ # Auto-infer session_id from checkpoint directory name
+ session_cfg = cfg.get("session", {}) or {}
+ session_id = session_cfg.get("id") or session_dir.name
+
  setup_logger(session_id=session_id)
- logger.info(
- f"[NextRec CLI] Predict start | version={get_nextrec_version()} | session_id={session_id} | checkpoint={session_dir.resolve()}"
+
+ log_cli_section("CLI")
+ log_kv_lines(
+ [
+ ("Mode", "predict"),
+ ("Version", get_nextrec_version()),
+ ("Session ID", session_id),
+ ("Checkpoint", session_dir.resolve()),
+ ("Config", config_file.resolve()),
+ ("Command", " ".join(sys.argv)),
+ ]
  )

  processor_path = Path(session_dir / "processor.pkl")
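
Taken with the later predict hunks, this makes `checkpoint_path` the primary key: the session id falls back to the checkpoint directory name, `model_config.yaml` is looked up inside the checkpoint directory, checkpoints are matched as `*.pt`, and output lands under `{checkpoint_path}/predictions/{name}.{save_data_format}`. Below is a hedged sketch of the dict such a predict config would load to; it uses only keys that appear in this diff, the values are illustrative, and the shipped `predict_config_template.yaml` may differ.

```python
# Hypothetical minimal predict config, written as the dict read_yaml() would
# return. Every key shown here is read somewhere in the cli.py hunks of this diff.
predict_cfg = {
    "checkpoint_path": "nextrec_logs/my_session",  # session_id inferred as "my_session"
    # "model_config" may be omitted: the CLI first checks
    # nextrec_logs/my_session/model_config.yaml, then the config directory.
    "predict": {
        "data_path": "dataset/ranking_task.csv",   # illustrative input file
        "batch_size": 512,
        "streaming": True,
        "chunk_size": 20000,
        "id_column": "user_id",                    # hypothetical column name
        "name": "pred",                            # output: predictions/pred.csv
        "save_data_format": "csv",
        "num_workers": 0,
    },
}
```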
@@ -339,24 +436,38 @@ def predict_model(predict_config_path: str) -> None:
  processor_path = session_dir / "processor" / "processor.pkl"

  predict_cfg = cfg.get("predict", {}) or {}
- model_cfg_path = resolve_path(
- cfg.get("model_config", "model_config.yaml"), config_dir
- )
- # feature_cfg_path = resolve_path(
- # cfg.get("feature_config", "feature_config.yaml"), config_dir
- # )
+
+ # Auto-find model_config in checkpoint directory if not specified
+ if "model_config" in cfg:
+ model_cfg_path = resolve_path(cfg["model_config"], config_dir)
+ else:
+ # Try to find model_config.yaml in checkpoint directory
+ auto_model_cfg = session_dir / "model_config.yaml"
+ if auto_model_cfg.exists():
+ model_cfg_path = auto_model_cfg
+ else:
+ # Fallback to config directory
+ model_cfg_path = resolve_path("model_config.yaml", config_dir)

  model_cfg = read_yaml(model_cfg_path)
- # feature_cfg = read_yaml(feature_cfg_path)
  model_cfg.setdefault("session_id", session_id)
  model_cfg.setdefault("params", {})

+ log_cli_section("Config")
+ log_kv_lines(
+ [
+ ("Predict config", config_file.resolve()),
+ ("Model config", model_cfg_path),
+ ("Processor", processor_path),
+ ]
+ )
+
  processor = DataProcessor.load(processor_path)

  # Load checkpoint and ensure required parameters are passed
  checkpoint_base = Path(session_dir)
  if checkpoint_base.is_dir():
- candidates = sorted(checkpoint_base.glob("*.model"))
+ candidates = sorted(checkpoint_base.glob("*.pt"))
  if not candidates:
  raise FileNotFoundError(
  f"[NextRec CLI Error]: Unable to find model checkpoint: {checkpoint_base}"
@@ -365,7 +476,7 @@ def predict_model(predict_config_path: str) -> None:
  config_dir_for_features = checkpoint_base
  else:
  model_file = (
- checkpoint_base.with_suffix(".model")
+ checkpoint_base.with_suffix(".pt")
  if checkpoint_base.suffix == ""
  else checkpoint_base
  )
@@ -415,40 +526,78 @@ def predict_model(predict_config_path: str) -> None:
  id_columns = [predict_cfg["id_column"]]
  model.id_columns = id_columns

+ effective_id_columns = id_columns or model.id_columns
+ log_cli_section("Features")
+ log_kv_lines(
+ [
+ ("Dense features", len(dense_features)),
+ ("Sparse features", len(sparse_features)),
+ ("Sequence features", len(sequence_features)),
+ ("Targets", len(target_cols)),
+ ("ID columns", len(effective_id_columns)),
+ ]
+ )
+
+ log_cli_section("Model")
+ log_kv_lines(
+ [
+ ("Model", model.__class__.__name__),
+ ("Checkpoint", model_file),
+ ("Device", predict_cfg.get("device", "cpu")),
+ ]
+ )
+
  rec_dataloader = RecDataLoader(
  dense_features=model.dense_features,
  sparse_features=model.sparse_features,
  sequence_features=model.sequence_features,
  target=None,
- id_columns=id_columns or model.id_columns,
+ id_columns=effective_id_columns,
  processor=processor,
  )

  data_path = resolve_path(predict_cfg["data_path"], config_dir)
  batch_size = predict_cfg.get("batch_size", 512)

+ log_cli_section("Data")
+ log_kv_lines(
+ [
+ ("Data path", data_path),
+ ("Format", predict_cfg.get("source_data_format", predict_cfg.get("data_format", "auto"))),
+ ("Batch size", batch_size),
+ ("Chunk size", predict_cfg.get("chunk_size", 20000)),
+ ("Streaming", predict_cfg.get("streaming", True)),
+ ]
+ )
+ logger.info("")
  pred_loader = rec_dataloader.create_dataloader(
  data=str(data_path),
  batch_size=batch_size,
  shuffle=False,
- load_full=predict_cfg.get("load_full", False),
+ streaming=predict_cfg.get("streaming", True),
  chunk_size=predict_cfg.get("chunk_size", 20000),
  )

- output_path = resolve_path(predict_cfg["output_path"], config_dir)
- output_path.parent.mkdir(parents=True, exist_ok=True)
+ # Build output path: {checkpoint_path}/predictions/{name}.{save_data_format}
+ save_format = predict_cfg.get("save_data_format", predict_cfg.get("save_format", "csv"))
+ pred_name = predict_cfg.get("name", "pred")
+ # Pass filename with extension to let model.predict handle path resolution
+ save_path = f"{pred_name}.{save_format}"

  start = time.time()
- model.predict(
+ logger.info("")
+ result = model.predict(
  data=pred_loader,
  batch_size=batch_size,
  include_ids=bool(id_columns),
  return_dataframe=False,
- save_path=output_path,
- save_format=predict_cfg.get("save_format", "csv"),
+ save_path=save_path,
+ save_format=save_format,
  num_workers=predict_cfg.get("num_workers", 0),
  )
  duration = time.time() - start
+ # When return_dataframe=False, result is the actual file path
+ output_path = result if isinstance(result, Path) else checkpoint_base / "predictions" / save_path
  logger.info(f"Prediction completed, results saved to: {output_path}")
  logger.info(f"Total time: {duration:.2f} seconds")

@@ -492,8 +641,6 @@ Examples:
  parser.add_argument("--predict_config", help="Prediction configuration file path")
  args = parser.parse_args()

- logger.info(get_nextrec_version())
-
  if not args.mode:
  parser.error("[NextRec CLI Error] --mode is required (train|predict)")

{nextrec-0.4.9 → nextrec-0.4.10}/nextrec/data/dataloader.py
@@ -102,9 +102,8 @@ class FileDataset(FeatureSet, IterableDataset):
  self.current_file_index = 0
  for file_path in self.file_paths:
  self.current_file_index += 1
- if self.total_files == 1:
- file_name = os.path.basename(file_path)
- logging.info(f"Processing file: {file_name}")
+ # Don't log file processing here to avoid interrupting progress bars
+ # File information is already displayed in the CLI data section
  if self.file_type == "csv":
  yield from self.read_csv_chunks(file_path)
  elif self.file_type == "parquet":
@@ -190,7 +189,7 @@ class RecDataLoader(FeatureSet):
  ),
  batch_size: int = 32,
  shuffle: bool = True,
- load_full: bool = True,
+ streaming: bool = False,
  chunk_size: int = 10000,
  num_workers: int = 0,
  sampler=None,
@@ -202,7 +201,7 @@ class RecDataLoader(FeatureSet):
  data: Data source, can be a dict, pd.DataFrame, file path (str), or existing DataLoader.
  batch_size: Batch size for DataLoader.
  shuffle: Whether to shuffle the data (ignored in streaming mode).
- load_full: If True, load full data into memory; if False, use streaming mode for large files.
+ streaming: If True, use streaming mode for large files; if False, load full data into memory.
  chunk_size: Chunk size for streaming mode (number of rows per chunk).
  num_workers: Number of worker processes for data loading.
  sampler: Optional sampler for DataLoader, only used for distributed training.
@@ -217,7 +216,7 @@ class RecDataLoader(FeatureSet):
  path=data,
  batch_size=batch_size,
  shuffle=shuffle,
- load_full=load_full,
+ streaming=streaming,
  chunk_size=chunk_size,
  num_workers=num_workers,
  )
@@ -230,7 +229,7 @@ class RecDataLoader(FeatureSet):
  path=data,
  batch_size=batch_size,
  shuffle=shuffle,
- load_full=load_full,
+ streaming=streaming,
  chunk_size=chunk_size,
  num_workers=num_workers,
  )
@@ -290,7 +289,7 @@ class RecDataLoader(FeatureSet):
  path: str | os.PathLike | list[str] | list[os.PathLike],
  batch_size: int,
  shuffle: bool,
- load_full: bool,
+ streaming: bool,
  chunk_size: int = 10000,
  num_workers: int = 0,
  ) -> DataLoader:
@@ -311,8 +310,17 @@ class RecDataLoader(FeatureSet):
  f"[RecDataLoader Error] Unsupported file extension in list: {suffix}"
  )
  file_type = "csv" if suffix == ".csv" else "parquet"
+ if streaming:
+ return self.load_files_streaming(
+ file_paths,
+ file_type,
+ batch_size,
+ chunk_size,
+ shuffle,
+ num_workers=num_workers,
+ )
  # Load full data into memory
- if load_full:
+ else:
  dfs = []
  total_bytes = 0
  for file_path in file_paths:
@@ -325,26 +333,17 @@ class RecDataLoader(FeatureSet):
  dfs.append(df)
  except MemoryError as exc:
  raise MemoryError(
- f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using load_full=False with streaming."
+ f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using streaming=True."
  ) from exc
  try:
  combined_df = pd.concat(dfs, ignore_index=True)
  except MemoryError as exc:
  raise MemoryError(
- f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use load_full=False to stream or reduce chunk_size."
+ f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use streaming=True or reduce chunk_size."
  ) from exc
  return self.create_from_memory(
  combined_df, batch_size, shuffle, num_workers=num_workers
  )
- else:
- return self.load_files_streaming(
- file_paths,
- file_type,
- batch_size,
- chunk_size,
- shuffle,
- num_workers=num_workers,
- )

  def load_files_streaming(
  self,
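
The net effect of the dataloader hunks above is a renamed, inverted flag: the old default `load_full=True` corresponds to the new default `streaming=False`, and `load_full=False` to `streaming=True`. A short before/after sketch of a file-based call site, assuming an already constructed `RecDataLoader` named `rec_loader` (the paths are hypothetical):

```python
# Before (0.4.9): stream large files by turning full loading off.
# loader = rec_loader.create_dataloader(
#     data=["part-0.parquet", "part-1.parquet"], load_full=False, chunk_size=10000)

# After (0.4.10): streaming is requested explicitly; omitting it loads into memory.
loader = rec_loader.create_dataloader(
    data=["part-0.parquet", "part-1.parquet"],  # hypothetical file list
    batch_size=32,
    shuffle=False,       # shuffle is ignored in streaming mode
    streaming=True,
    chunk_size=10000,
)
```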

{nextrec-0.4.9 → nextrec-0.4.10}/nextrec/models/ranking/deepfm.py
@@ -1,12 +1,11 @@
  """
  Date: create on 27/10/2025
  Checkpoint: edit on 24/11/2025
- Author:
- Yang Zhou,zyaztec@gmail.com
+ Author: Yang Zhou,zyaztec@gmail.com
  Reference:
- [1] Guo H, Tang R, Ye Y, et al. DeepFM: A factorization-machine based neural network
- for CTR prediction[J]. arXiv preprint arXiv:1703.04247, 2017.
- (https://arxiv.org/abs/1703.04247)
+ [1] Guo H, Tang R, Ye Y, et al. DeepFM: A factorization-machine based neural network
+ for CTR prediction[J]. arXiv preprint arXiv:1703.04247, 2017.
+ (https://arxiv.org/abs/1703.04247)

  DeepFM combines a Factorization Machine (FM) for explicit second-order feature
  interactions with a deep MLP for high-order nonlinear patterns. Both parts share