nextrec 0.4.34__tar.gz → 0.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201)
  1. {nextrec-0.4.34 → nextrec-0.5.1}/.gitignore +3 -0
  2. {nextrec-0.4.34 → nextrec-0.5.1}/PKG-INFO +10 -4
  3. {nextrec-0.4.34 → nextrec-0.5.1}/README.md +5 -3
  4. {nextrec-0.4.34 → nextrec-0.5.1}/README_en.md +5 -3
  5. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/conf.py +1 -1
  6. nextrec-0.5.1/nextrec/__version__.py +1 -0
  7. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/activation.py +7 -13
  8. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/layers.py +28 -94
  9. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/model.py +512 -4
  10. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/cli.py +102 -20
  11. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/data/data_processing.py +8 -13
  12. nextrec-0.5.1/nextrec/data/preprocessor.py +1196 -0
  13. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/afm.py +4 -9
  14. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/dien.py +7 -8
  15. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/ffm.py +2 -2
  16. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/retrieval/sdm.py +1 -2
  17. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/sequential/hstu.py +0 -2
  18. nextrec-0.5.1/nextrec/utils/onnx_utils.py +252 -0
  19. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/torch_utils.py +6 -1
  20. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/predict_config.yaml +3 -1
  21. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/predict_config_template.yaml +3 -0
  22. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/train_config.yaml +5 -0
  23. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/train_config_template.yaml +6 -0
  24. {nextrec-0.4.34 → nextrec-0.5.1}/pyproject.toml +5 -1
  25. {nextrec-0.4.34 → nextrec-0.5.1}/requirements.txt +4 -0
  26. {nextrec-0.4.34 → nextrec-0.5.1}/test/run_tests.py +4 -1
  27. nextrec-0.5.1/test/test_onnx_models.py +620 -0
  28. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_preprocessor.py +8 -6
  29. nextrec-0.5.1/tutorials/distributed/example_distributed_training.py +234 -0
  30. nextrec-0.5.1/tutorials/distributed/example_distributed_training_large_dataset.py +251 -0
  31. nextrec-0.5.1/tutorials/example_match.py +261 -0
  32. nextrec-0.5.1/tutorials/example_multitask.py +221 -0
  33. nextrec-0.5.1/tutorials/example_onnx.py +300 -0
  34. nextrec-0.5.1/tutorials/example_ranking_din.py +217 -0
  35. nextrec-0.5.1/tutorials/example_tree.py +205 -0
  36. nextrec-0.5.1/tutorials/movielen_match_dssm.py +270 -0
  37. nextrec-0.5.1/tutorials/movielen_ranking_deepfm.py +183 -0
  38. nextrec-0.5.1/tutorials/run_all_match_models.py +303 -0
  39. nextrec-0.5.1/tutorials/run_all_multitask_models.py +395 -0
  40. nextrec-0.5.1/tutorials/run_all_ranking_models.py +388 -0
  41. nextrec-0.4.34/nextrec/__version__.py +0 -1
  42. nextrec-0.4.34/nextrec/data/preprocessor.py +0 -1593
  43. nextrec-0.4.34/nextrec/models/multi_task/[pre]star.py +0 -192
  44. nextrec-0.4.34/tutorials/distributed/example_distributed_training.py +0 -158
  45. nextrec-0.4.34/tutorials/distributed/example_distributed_training_large_dataset.py +0 -158
  46. nextrec-0.4.34/tutorials/example_match.py +0 -163
  47. nextrec-0.4.34/tutorials/example_multitask.py +0 -135
  48. nextrec-0.4.34/tutorials/example_ranking_din.py +0 -125
  49. nextrec-0.4.34/tutorials/example_tree.py +0 -97
  50. nextrec-0.4.34/tutorials/movielen_match_dssm.py +0 -155
  51. nextrec-0.4.34/tutorials/movielen_ranking_deepfm.py +0 -72
  52. nextrec-0.4.34/tutorials/run_all_match_models.py +0 -210
  53. nextrec-0.4.34/tutorials/run_all_multitask_models.py +0 -285
  54. nextrec-0.4.34/tutorials/run_all_ranking_models.py +0 -264
  55. {nextrec-0.4.34 → nextrec-0.5.1}/.github/workflows/publish.yml +0 -0
  56. {nextrec-0.4.34 → nextrec-0.5.1}/.github/workflows/tests.yml +0 -0
  57. {nextrec-0.4.34 → nextrec-0.5.1}/.readthedocs.yaml +0 -0
  58. {nextrec-0.4.34 → nextrec-0.5.1}/CODE_OF_CONDUCT.md +0 -0
  59. {nextrec-0.4.34 → nextrec-0.5.1}/CONTRIBUTING.md +0 -0
  60. {nextrec-0.4.34 → nextrec-0.5.1}/LICENSE +0 -0
  61. {nextrec-0.4.34 → nextrec-0.5.1}/MANIFEST.in +0 -0
  62. {nextrec-0.4.34 → nextrec-0.5.1}/assets/Feature Configuration.png +0 -0
  63. {nextrec-0.4.34 → nextrec-0.5.1}/assets/Model Parameters.png +0 -0
  64. {nextrec-0.4.34 → nextrec-0.5.1}/assets/Training Configuration.png +0 -0
  65. {nextrec-0.4.34 → nextrec-0.5.1}/assets/Training logs.png +0 -0
  66. {nextrec-0.4.34 → nextrec-0.5.1}/assets/logo.png +0 -0
  67. {nextrec-0.4.34 → nextrec-0.5.1}/assets/mmoe_tutorial.png +0 -0
  68. {nextrec-0.4.34 → nextrec-0.5.1}/assets/nextrec_diagram.png +0 -0
  69. {nextrec-0.4.34 → nextrec-0.5.1}/assets/test data.png +0 -0
  70. {nextrec-0.4.34 → nextrec-0.5.1}/dataset/ctcvr_task.csv +0 -0
  71. {nextrec-0.4.34 → nextrec-0.5.1}/dataset/ecommerce_task.csv +0 -0
  72. {nextrec-0.4.34 → nextrec-0.5.1}/dataset/match_task.csv +0 -0
  73. {nextrec-0.4.34 → nextrec-0.5.1}/dataset/movielens_100k.csv +0 -0
  74. {nextrec-0.4.34 → nextrec-0.5.1}/dataset/multitask_task.csv +0 -0
  75. {nextrec-0.4.34 → nextrec-0.5.1}/dataset/ranking_task.csv +0 -0
  76. {nextrec-0.4.34 → nextrec-0.5.1}/docs/en/Getting started guide.md +0 -0
  77. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/Makefile +0 -0
  78. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/index.md +0 -0
  79. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/make.bat +0 -0
  80. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/modules.rst +0 -0
  81. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/nextrec.basic.rst +0 -0
  82. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/nextrec.data.rst +0 -0
  83. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/nextrec.loss.rst +0 -0
  84. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/nextrec.rst +0 -0
  85. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/nextrec.utils.rst +0 -0
  86. {nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/requirements.txt +0 -0
  87. {nextrec-0.4.34 → nextrec-0.5.1}/docs/zh/快速上手.md +0 -0
  88. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/__init__.py +0 -0
  89. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/__init__.py +0 -0
  90. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/asserts.py +0 -0
  91. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/callback.py +0 -0
  92. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/features.py +0 -0
  93. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/heads.py +0 -0
  94. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/loggers.py +0 -0
  95. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/metrics.py +0 -0
  96. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/session.py +0 -0
  97. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/summary.py +0 -0
  98. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/data/__init__.py +0 -0
  99. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/data/batch_utils.py +0 -0
  100. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/data/data_utils.py +0 -0
  101. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/data/dataloader.py +0 -0
  102. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/loss/__init__.py +0 -0
  103. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/loss/grad_norm.py +0 -0
  104. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/loss/listwise.py +0 -0
  105. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/loss/pairwise.py +0 -0
  106. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/loss/pointwise.py +0 -0
  107. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/generative/__init__.py +0 -0
  108. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/generative/tiger.py +0 -0
  109. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/[pre]aitm.py +0 -0
  110. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/[pre]snr_trans.py +0 -0
  111. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/__init__.py +0 -0
  112. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/apg.py +0 -0
  113. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/cross_stitch.py +0 -0
  114. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/escm.py +0 -0
  115. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/esmm.py +0 -0
  116. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/hmoe.py +0 -0
  117. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/mmoe.py +0 -0
  118. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/pepnet.py +0 -0
  119. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/ple.py +0 -0
  120. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/poso.py +0 -0
  121. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/multi_task/share_bottom.py +0 -0
  122. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/__init__.py +0 -0
  123. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/autoint.py +0 -0
  124. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/dcn.py +0 -0
  125. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/dcn_v2.py +0 -0
  126. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/deepfm.py +0 -0
  127. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/din.py +0 -0
  128. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/eulernet.py +0 -0
  129. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/fibinet.py +0 -0
  130. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/fm.py +0 -0
  131. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/lr.py +0 -0
  132. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/masknet.py +0 -0
  133. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/pnn.py +0 -0
  134. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/widedeep.py +0 -0
  135. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/ranking/xdeepfm.py +0 -0
  136. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/representation/__init__.py +0 -0
  137. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/representation/rqvae.py +0 -0
  138. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/retrieval/__init__.py +0 -0
  139. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/retrieval/dssm.py +0 -0
  140. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/retrieval/dssm_v2.py +0 -0
  141. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/retrieval/mind.py +0 -0
  142. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/retrieval/youtube_dnn.py +0 -0
  143. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/tree_base/__init__.py +0 -0
  144. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/tree_base/base.py +0 -0
  145. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/tree_base/catboost.py +0 -0
  146. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/tree_base/lightgbm.py +0 -0
  147. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/models/tree_base/xgboost.py +0 -0
  148. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/__init__.py +0 -0
  149. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/config.py +0 -0
  150. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/console.py +0 -0
  151. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/data.py +0 -0
  152. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/embedding.py +0 -0
  153. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/loss.py +0 -0
  154. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/model.py +0 -0
  155. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec/utils/types.py +0 -0
  156. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/NextRec-CLI.md +0 -0
  157. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
  158. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/feature_config.yaml +0 -0
  159. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  160. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/apg.yaml +0 -0
  161. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  162. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/cross_stitch.yaml +0 -0
  163. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  164. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  165. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  166. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/escm.yaml +0 -0
  167. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  168. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  169. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  170. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/hmoe.yaml +0 -0
  171. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  172. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  173. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/pepnet.yaml +0 -0
  174. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  175. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  176. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  177. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  178. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  179. {nextrec-0.4.34 → nextrec-0.5.1}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  180. {nextrec-0.4.34 → nextrec-0.5.1}/pytest.ini +0 -0
  181. {nextrec-0.4.34 → nextrec-0.5.1}/scripts/format_code.py +0 -0
  182. {nextrec-0.4.34 → nextrec-0.5.1}/test/__init__.py +0 -0
  183. {nextrec-0.4.34 → nextrec-0.5.1}/test/conftest.py +0 -0
  184. {nextrec-0.4.34 → nextrec-0.5.1}/test/helpers.py +0 -0
  185. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_base_model_regularization.py +0 -0
  186. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_generative_models.py +0 -0
  187. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_layers.py +0 -0
  188. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_losses.py +0 -0
  189. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_match_models.py +0 -0
  190. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_multitask_models.py +0 -0
  191. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_ranking_models.py +0 -0
  192. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_utils_console.py +0 -0
  193. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_utils_data.py +0 -0
  194. {nextrec-0.4.34 → nextrec-0.5.1}/test/test_utils_embedding.py +0 -0
  195. {nextrec-0.4.34 → nextrec-0.5.1}/test_requirements.txt +0 -0
  196. {nextrec-0.4.34 → nextrec-0.5.1}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  197. {nextrec-0.4.34 → nextrec-0.5.1}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  198. {nextrec-0.4.34 → nextrec-0.5.1}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  199. {nextrec-0.4.34 → nextrec-0.5.1}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  200. {nextrec-0.4.34 → nextrec-0.5.1}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  201. {nextrec-0.4.34 → nextrec-0.5.1}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0

{nextrec-0.4.34 → nextrec-0.5.1}/.gitignore
@@ -128,3 +128,6 @@ pypirc.template
 
  # Sphinx build
  docs/rtd/_build/
+
+ *.onnx
+ artifacts/

{nextrec-0.4.34 → nextrec-0.5.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.4.34
+ Version: 0.5.1
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -24,10 +24,14 @@ Requires-Dist: numpy<2.0,>=1.21; sys_platform == 'linux' and python_version < '3
  Requires-Dist: numpy<3.0,>=1.26; sys_platform == 'linux' and python_version >= '3.12'
  Requires-Dist: numpy>=1.23.0; sys_platform == 'win32'
  Requires-Dist: numpy>=1.24.0; sys_platform == 'darwin'
+ Requires-Dist: onnx>=1.16.0
+ Requires-Dist: onnxruntime>=1.18.0
+ Requires-Dist: onnxscript>=0.1.1
  Requires-Dist: pandas<2.0,>=1.5; sys_platform == 'linux' and python_version < '3.12'
  Requires-Dist: pandas<2.3.0,>=2.1.0; sys_platform == 'win32'
  Requires-Dist: pandas>=2.0.0; sys_platform == 'darwin'
  Requires-Dist: pandas>=2.1.0; sys_platform == 'linux' and python_version >= '3.12'
+ Requires-Dist: polars>=0.20.0
  Requires-Dist: pyarrow<13.0.0,>=10.0.0; sys_platform == 'linux' and python_version < '3.12'
  Requires-Dist: pyarrow<15.0.0,>=12.0.0; sys_platform == 'win32'
  Requires-Dist: pyarrow>=12.0.0; sys_platform == 'darwin'
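
The new requirement pins above back the ONNX export path (onnx, onnxruntime, onnxscript) and the faster preprocessing path (polars). As a minimal sketch, assuming the `packaging` distribution is available in the environment, the installed versions can be checked against these floors; nothing below ships with nextrec:

```python
# Hypothetical environment check against the new minimum versions; not part of nextrec.
from importlib.metadata import version
from packaging.version import Version

MINIMUMS = {
    "onnx": "1.16.0",
    "onnxruntime": "1.18.0",
    "onnxscript": "0.1.1",
    "polars": "0.20.0",
}

for pkg, floor in MINIMUMS.items():
    installed = Version(version(pkg))  # raises PackageNotFoundError if the package is missing
    status = "OK" if installed >= Version(floor) else "TOO OLD"
    print(f"{pkg}: {installed} (needs >= {floor}) {status}")
```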
@@ -69,7 +73,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.34-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.5.1-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
 
  Chinese Documentation | [English Version](README_en.md)
@@ -102,6 +106,7 @@ NextRec is a modern recommendation framework built on PyTorch, designed for research and engineering
  - **Efficient training and evaluation**: built-in optimizers, learning-rate schedulers, early stopping, model checkpointing, and detailed logging, ready to use out of the box.
 
  ## NextRec Progress
+ - **28/01/2026** Added support for ONNX export and loading in v0.4.39, and greatly accelerated data preprocessing (up to 9x speedup)
  - **01/01/2026** Happy New Year! v0.4.27 adds support for several multi-task models: [APG](nextrec/models/multi_task/apg.py), [ESCM](nextrec/models/multi_task/escm.py), [HMoE](nextrec/models/multi_task/hmoe.py), [Cross Stitch](nextrec/models/multi_task/cross_stitch.py)
  - **28/12/2025** v0.4.21 adds support for SwanLab and Wandb, configured through the model's `fit` method: `use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
  - **21/12/2025** v0.4.16 adds support for [GradNorm](/nextrec/loss/grad_norm.py), configured via `loss_weight='grad_norm'` in compile
@@ -136,6 +141,7 @@ pip install nextrec # or pip install -e .
  - [example_multitask.py](/tutorials/example_multitask.py) - ESMM multi-task learning example on an e-commerce dataset
  - [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) - DSSM retrieval model trained on the MovieLens 100k dataset
 
+ - [example_onnx.py](/tutorials/example_onnx.py) - Train and export ONNX models with NextRec
  - [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) - Single-machine multi-GPU training with NextRec
 
  - [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) - Quickly verify that all ranking models are runnable
@@ -254,11 +260,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 
  Prediction results are always saved to `{checkpoint_path}/predictions/{name}.{save_data_format}`.
 
- > As of the current version 0.4.34, the NextRec CLI supports single-machine training; distributed training features are still under development.
+ > As of the current version 0.5.1, the NextRec CLI supports single-machine training; distributed training features are still under development.
 
  ## Supported Platforms
 
- The current version is 0.4.34. All models and test code have been verified on the platforms below. If you run into compatibility issues, please open an issue with an error report and your system version:
+ The current version is 0.5.1. All models and test code have been verified on the platforms below. If you run into compatibility issues, please open an issue with an error report and your system version:
 
  | Platform | Configuration |
  |------|------|
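
The preprocessing speedup noted in the changelog (up to 9x) goes together with the rewritten nextrec/data/preprocessor.py and the new polars dependency. The package's actual implementation is not visible in this diff; the sketch below, with made-up column names, only illustrates the general shift from row-wise pandas transforms to vectorized polars expressions, which is typically where speedups of that order come from:

```python
# Illustrative only: vectorized polars expressions vs. a row-wise pandas apply.
import pandas as pd
import polars as pl

pdf = pd.DataFrame({"user_id": ["u1", "u2", "u1", "u3"], "rating": [3, 5, 4, 1]})

# Row-wise pandas baseline: one Python call per row.
mapping = {u: i for i, u in enumerate(pdf["user_id"].unique())}
pdf["user_idx_slow"] = pdf["user_id"].apply(lambda u: mapping[u])

# polars: whole-column expressions evaluated in native code.
out = pl.from_pandas(pdf).with_columns(
    pl.col("user_id").cast(pl.Categorical).to_physical().alias("user_idx_fast"),
    ((pl.col("rating") - pl.col("rating").mean()) / pl.col("rating").std()).alias("rating_z"),
)
print(out)
```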

{nextrec-0.4.34 → nextrec-0.5.1}/README.md
@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.34-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.5.1-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
 
  Chinese Documentation | [English Version](README_en.md)
@@ -41,6 +41,7 @@ NextRec is a modern recommendation framework built on PyTorch, designed for research and engineering
  - **Efficient training and evaluation**: built-in optimizers, learning-rate schedulers, early stopping, model checkpointing, and detailed logging, ready to use out of the box.
 
  ## NextRec Progress
+ - **28/01/2026** Added support for ONNX export and loading in v0.4.39, and greatly accelerated data preprocessing (up to 9x speedup)
  - **01/01/2026** Happy New Year! v0.4.27 adds support for several multi-task models: [APG](nextrec/models/multi_task/apg.py), [ESCM](nextrec/models/multi_task/escm.py), [HMoE](nextrec/models/multi_task/hmoe.py), [Cross Stitch](nextrec/models/multi_task/cross_stitch.py)
  - **28/12/2025** v0.4.21 adds support for SwanLab and Wandb, configured through the model's `fit` method: `use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
  - **21/12/2025** v0.4.16 adds support for [GradNorm](/nextrec/loss/grad_norm.py), configured via `loss_weight='grad_norm'` in compile
@@ -75,6 +76,7 @@ pip install nextrec # or pip install -e .
  - [example_multitask.py](/tutorials/example_multitask.py) - ESMM multi-task learning example on an e-commerce dataset
  - [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) - DSSM retrieval model trained on the MovieLens 100k dataset
 
+ - [example_onnx.py](/tutorials/example_onnx.py) - Train and export ONNX models with NextRec
  - [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) - Single-machine multi-GPU training with NextRec
 
  - [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) - Quickly verify that all ranking models are runnable
@@ -193,11 +195,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 
  Prediction results are always saved to `{checkpoint_path}/predictions/{name}.{save_data_format}`.
 
- > As of the current version 0.4.34, the NextRec CLI supports single-machine training; distributed training features are still under development.
+ > As of the current version 0.5.1, the NextRec CLI supports single-machine training; distributed training features are still under development.
 
  ## Supported Platforms
 
- The current version is 0.4.34. All models and test code have been verified on the platforms below. If you run into compatibility issues, please open an issue with an error report and your system version:
+ The current version is 0.5.1. All models and test code have been verified on the platforms below. If you run into compatibility issues, please open an issue with an error report and your system version:
 
  | Platform | Configuration |
  |------|------|

{nextrec-0.4.34 → nextrec-0.5.1}/README_en.md
@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.34-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.5.1-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
 
  English | [中文文档](README.md)
@@ -44,6 +44,7 @@ NextRec is a modern recommendation framework built on PyTorch, delivering a unif
 
  ## NextRec Progress
 
+ - **28/01/2026** Added support for ONNX export and loading in v0.4.39, and significantly accelerated data preprocessing speed (up to 9x speedup)
  - **01/01/2026** Happy New Year! In v0.4.27, added support for multiple multi-task models: [APG](/nextrec/models/multi_task/apg.py), [ESCM](/nextrec/models/multi_task/escm.py), [HMoE](/nextrec/models/multi_task/hmoe.py), [Cross Stitch](/nextrec/models/multi_task/cross_stitch.py)
  - **28/12/2025** Added support for SwanLab and Weights & Biases in v0.4.21, configurable via the model `fit` method: `use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
  - **21/12/2025** Added support for [GradNorm](/nextrec/loss/grad_norm.py) in v0.4.16, configurable via `loss_weight='grad_norm'` in the compile method
@@ -79,6 +80,7 @@ See `tutorials/` for examples covering ranking, retrieval, multi-task learning,
  - [example_multitask.py](/tutorials/example_multitask.py) — ESMM multi-task learning training on e-commerce dataset
  - [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) — DSSM retrieval model training on MovieLens 100k dataset
 
+ - [example_onnx.py](/tutorials/example_onnx.py) — Train and export models to ONNX format with NextRec
  - [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) — Single-machine multi-GPU training with NextRec
 
  - [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) — Quickly validate availability of all ranking models
@@ -196,11 +198,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 
  Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.
 
- > As of version 0.4.34, NextRec CLI supports single-machine training; distributed training features are currently under development.
+ > As of version 0.5.1, NextRec CLI supports single-machine training; distributed training features are currently under development.
 
  ## Platform Compatibility
 
- The current version is 0.4.34. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
+ The current version is 0.5.1. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
 
  | Platform | Configuration |
  |----------|---------------|
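
The ONNX-related entries above (the v0.4.39 changelog line, example_onnx.py, and the new nextrec/utils/onnx_utils.py in the file list) describe export and loading support, but NextRec's own helper API is not visible in this diff. As a rough sketch of the workflow those files presumably wrap, here is plain torch.onnx.export plus onnxruntime inference on a stand-in model; the ToyRanker module and tensor names are invented for illustration:

```python
# Sketch only: generic PyTorch -> ONNX export and onnxruntime inference.
# "ToyRanker" is a placeholder, not a NextRec model or API.
import torch
import torch.nn as nn
import onnxruntime as ort

class ToyRanker(nn.Module):
    def __init__(self, num_items: int = 1000, dim: int = 16):
        super().__init__()
        self.emb = nn.Embedding(num_items, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 8), nn.ReLU(), nn.Linear(8, 1))

    def forward(self, item_id: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.mlp(self.emb(item_id)))

model = ToyRanker().eval()
dummy = torch.randint(0, 1000, (4,))
torch.onnx.export(
    model, (dummy,), "toy_ranker.onnx",
    input_names=["item_id"], output_names=["score"],
    dynamic_axes={"item_id": {0: "batch"}, "score": {0: "batch"}},
)

sess = ort.InferenceSession("toy_ranker.onnx", providers=["CPUExecutionProvider"])
scores = sess.run(["score"], {"item_id": dummy.numpy()})[0]
print(scores.shape)  # (4, 1)
```

For the real entry points, see tutorials/example_onnx.py and nextrec/utils/onnx_utils.py in the 0.5.1 tarball.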

{nextrec-0.4.34 → nextrec-0.5.1}/docs/rtd/conf.py
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
  project = "NextRec"
  copyright = "2026, Yang Zhou"
  author = "Yang Zhou"
- release = "0.4.34"
+ release = "0.5.1"
 
  extensions = [
      "myst_parser",

nextrec-0.5.1/nextrec/__version__.py
@@ -0,0 +1 @@
+ __version__ = "0.5.1"

{nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/activation.py
@@ -25,21 +25,15 @@ class Dice(nn.Module):
      def __init__(self, emb_size: int, epsilon: float = 1e-3):
          super(Dice, self).__init__()
          self.alpha = nn.Parameter(torch.zeros(emb_size))
-         self.bn = nn.BatchNorm1d(emb_size, eps=epsilon)
+         self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)
 
      def forward(self, x):
-         # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
-         if x.dim() == 2:  # (B, E)
-             x_norm = self.bn(x)
-             p = torch.sigmoid(x_norm)
-             return x * (self.alpha + (1 - self.alpha) * p)
-
-         if x.dim() == 3:  # (B, T, E)
-             b, t, e = x.shape
-             x2 = x.reshape(-1, e)  # (B*T, E)
-             x_norm = self.bn(x2)
-             p = torch.sigmoid(x_norm).reshape(b, t, e)
-             return x * (self.alpha + (1 - self.alpha) * p)
+         # keep original shape for reshaping back after batch norm
+         orig_shape = x.shape  # x: [N, L, emb_size] or [N, emb_size]
+         x2 = x.reshape(-1, orig_shape[-1])  # x2: [N*L, emb_size] or [N, emb_size]
+         x_norm = self.bn(x2)
+         p = torch.sigmoid(x_norm).reshape(orig_shape)
+         return x * (self.alpha + (1 - self.alpha) * p)
 
 
  def activation_layer(
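
The rework above collapses the separate 2D/3D branches of Dice into a single reshape-normalize-reshape path and makes the batch norm affine-free, so the only learned gating parameter is alpha; a single code path also avoids dimension-dependent Python branching during tracing. A self-contained sketch that mirrors the new forward (the class name here is invented, not NextRec's):

```python
# Standalone sketch of the reworked Dice forward; mirrors the hunk above.
import torch
import torch.nn as nn

class DiceSketch(nn.Module):
    def __init__(self, emb_size: int, epsilon: float = 1e-3):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(emb_size))
        # affine=False: no learned scale/shift inside BN; only alpha gates the output
        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape                # [N, E] or [N, L, E]
        x2 = x.reshape(-1, orig_shape[-1])  # flatten leading dims to [N*L, E]
        p = torch.sigmoid(self.bn(x2)).reshape(orig_shape)
        return x * (self.alpha + (1 - self.alpha) * p)

dice = DiceSketch(emb_size=8).eval()
print(dice(torch.randn(4, 8)).shape)      # torch.Size([4, 8])
print(dice(torch.randn(4, 10, 8)).shape)  # torch.Size([4, 10, 8])
```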

{nextrec-0.4.34 → nextrec-0.5.1}/nextrec/basic/layers.py
@@ -2,7 +2,7 @@
  Layer implementations used across NextRec.
 
  Date: create on 27/10/2025
- Checkpoint: edit on 22/01/2026
+ Checkpoint: edit on 25/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """
 
@@ -79,10 +79,12 @@ class PredictionLayer(nn.Module):
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          if x.dim() == 1:
              x = x.unsqueeze(0)  # (1 * total_dim)
-         if x.shape[-1] != self.total_dim:
-             raise ValueError(
-                 f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
-             )
+         if not torch.onnx.is_in_onnx_export():
+             if x.shape[-1] != self.total_dim:
+                 raise ValueError(
+                     f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+                 )
+
          logits = x if self.bias is None else x + self.bias
          outputs = []
          for task_type, (start, end) in zip(self.task_types, self.task_slices):
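
torch.onnx.is_in_onnx_export() returns True only while torch.onnx.export is tracing the graph, so the eager shape check above still raises during normal training and inference but is skipped during export, where example-driven or symbolic shapes can make the comparison misfire. A minimal sketch of the same guard pattern on an invented module (not a NextRec class):

```python
# Sketch of the export-guard pattern from the hunk above; "Checked" is illustrative only.
import torch
import torch.nn as nn

class Checked(nn.Module):
    def __init__(self, expected_dim: int = 4):
        super().__init__()
        self.expected_dim = expected_dim
        self.linear = nn.Linear(expected_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.onnx.is_in_onnx_export():
            if x.shape[-1] != self.expected_dim:
                raise ValueError(f"expected last dim {self.expected_dim}, got {x.shape[-1]}")
        return self.linear(x)

m = Checked()
m(torch.randn(2, 4))                                         # eager call: the check runs
torch.onnx.export(m, (torch.randn(2, 4),), "checked.onnx")   # during tracing the check is skipped
```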
@@ -216,7 +218,7 @@ class EmbeddingLayer(nn.Module):
 
          elif isinstance(feature, SequenceFeature):
              seq_input = x[feature.name].long()
-             if feature.max_len is not None and seq_input.size(1) > feature.max_len:
+             if feature.max_len is not None:
                  seq_input = seq_input[:, -feature.max_len :]
 
              embed = self.embed_dict[feature.embedding_name]
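
Dropping the seq_input.size(1) > feature.max_len comparison removes a shape-dependent Python branch (whose outcome would otherwise be frozen at trace time during ONNX export) and leans on the fact that negative slicing already leaves shorter sequences untouched:

```python
# Negative slicing keeps the whole tensor when it is already short enough.
import torch

max_len = 5
short = torch.arange(6).reshape(2, 3)   # seq len 3 < max_len
long = torch.arange(16).reshape(2, 8)   # seq len 8 > max_len
print(short[:, -max_len:].shape)  # torch.Size([2, 3]) -- unchanged
print(long[:, -max_len:].shape)   # torch.Size([2, 5]) -- truncated to the last 5 steps
```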
@@ -279,10 +281,11 @@
              value = value.view(value.size(0), -1)  # [B, input_dim]
              input_dim = feature.input_dim
              assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
-             if value.shape[1] != assert_input_dim:
-                 raise ValueError(
-                     f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
-                 )
+             if not torch.onnx.is_in_onnx_export():
+                 if value.shape[1] != assert_input_dim:
+                     raise ValueError(
+                         f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
+                     )
              if not feature.use_projection:
                  return value
              dense_layer = self.dense_transforms[feature.name]
@@ -328,29 +331,10 @@ class InputMask(nn.Module):
          feature: SequenceFeature,
          seq_tensor: torch.Tensor | None = None,
      ):
-         if seq_tensor is not None:
-             values = seq_tensor
-         else:
-             values = x[feature.name]
-         values = values.long()
+         values = seq_tensor if seq_tensor is not None else x[feature.name]
+         values = values.long().view(values.size(0), -1)
          padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
-         mask = values != padding_idx
-
-         if mask.dim() == 1:
-             # [B] -> [B, 1, 1]
-             mask = mask.unsqueeze(1).unsqueeze(2)
-         elif mask.dim() == 2:
-             # [B, L] -> [B, 1, L]
-             mask = mask.unsqueeze(1)
-         elif mask.dim() == 3:
-             # [B, 1, L]
-             # [B, L, 1] -> [B, L] -> [B, 1, L]
-             if mask.size(1) != 1 and mask.size(2) == 1:
-                 mask = mask.squeeze(-1).unsqueeze(1)
-         else:
-             raise ValueError(
-                 f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
-             )
+         mask = (values != padding_idx).unsqueeze(1)
          return mask.float()
 
 
@@ -928,39 +912,22 @@ class AttentionPoolingLayer(nn.Module):
          output: [batch_size, embedding_dim] - attention pooled representation
          """
          batch_size, sequence_length, embedding_dim = keys.shape
-         assert query.shape == (
-             batch_size,
-             embedding_dim,
-         ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
-         if mask is None and keys_length is not None:
-             # keys_length: (batch_size,)
-             device = keys.device
-             seq_range = torch.arange(sequence_length, device=device).unsqueeze(
-                 0
-             )  # (1, sequence_length)
-             mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
-         if mask is not None:
-             if mask.dim() == 2:
-                 # (B, L)
-                 mask = mask.unsqueeze(-1)
-             elif (
-                 mask.dim() == 3
-                 and mask.shape[1] == 1
-                 and mask.shape[2] == sequence_length
-             ):
-                 # (B, 1, L) -> (B, L, 1)
-                 mask = mask.transpose(1, 2)
-             elif (
-                 mask.dim() == 3
-                 and mask.shape[1] == sequence_length
-                 and mask.shape[2] == 1
-             ):
-                 pass
+         if mask is None:
+             if keys_length is None:
+                 mask = torch.ones(
+                     (batch_size, sequence_length), device=keys.device, dtype=keys.dtype
+                 )
              else:
+                 device = keys.device
+                 seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
+                 mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
+         else:
+             mask = mask.to(keys.dtype).reshape(batch_size, -1)
+             if mask.shape[1] != sequence_length:
                  raise ValueError(
                      f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
                  )
-         mask = mask.to(keys.dtype)
+         mask = mask.unsqueeze(-1)
          # Expand query to (B, L, D)
          query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
          # [query, key, query-key, query*key] -> (B, L, 4D)
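
Together with the simplified InputMask earlier (which now always emits a [B, 1, L] mask), the rewritten block normalizes whatever AttentionPoolingLayer receives into a flat [B, L] mask first (all ones when nothing is supplied, derived from keys_length when lengths are given, or the caller's mask reshaped) and only then appends the trailing dimension. A standalone sketch of that normalization, outside the layer; the helper name is invented:

```python
# Standalone sketch of the new mask normalization; mirrors the hunk above.
import torch

def normalize_mask(keys, mask=None, keys_length=None):
    batch_size, sequence_length, _ = keys.shape
    if mask is None:
        if keys_length is None:
            mask = torch.ones((batch_size, sequence_length), device=keys.device, dtype=keys.dtype)
        else:
            seq_range = torch.arange(sequence_length, device=keys.device).unsqueeze(0)
            mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
    else:
        mask = mask.to(keys.dtype).reshape(batch_size, -1)
        if mask.shape[1] != sequence_length:
            raise ValueError(f"Unsupported mask shape: {mask.shape}")
    return mask.unsqueeze(-1)  # [B, L, 1]

keys = torch.randn(2, 5, 8)
print(normalize_mask(keys, keys_length=torch.tensor([3, 5])).squeeze(-1))
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])
```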
@@ -1000,36 +967,3 @@ class RMSNorm(torch.nn.Module):
          variance = torch.mean(x**2, dim=-1, keepdim=True)
          x_normalized = x * torch.rsqrt(variance + self.eps)
          return self.weight * x_normalized
-
-
- class DomainBatchNorm(nn.Module):
-     """
-     Domain-specific BatchNorm (applied per-domain with a shared interface).
-     """
-
-     def __init__(self, num_features: int, num_domains: int):
-         super().__init__()
-         if num_domains < 1:
-             raise ValueError("num_domains must be >= 1")
-         self.bns = nn.ModuleList(
-             [nn.BatchNorm1d(num_features) for _ in range(num_domains)]
-         )
-
-     def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
-         if x.dim() != 2:
-             raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
-         output = x.clone()
-         if domain_mask.dim() == 1:
-             domain_ids = domain_mask.long()
-             for idx, bn in enumerate(self.bns):
-                 mask = domain_ids == idx
-                 if mask.any():
-                     output[mask] = bn(x[mask])
-             return output
-         if domain_mask.dim() != 2:
-             raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
-         for idx, bn in enumerate(self.bns):
-             mask = domain_mask[:, idx] > 0
-             if mask.any():
-                 output[mask] = bn(x[mask])
-         return output