nextrec 0.4.33.tar.gz → 0.5.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (209)
  1. {nextrec-0.4.33 → nextrec-0.5.0}/.gitignore +3 -0
  2. {nextrec-0.4.33 → nextrec-0.5.0}/PKG-INFO +10 -4
  3. {nextrec-0.4.33 → nextrec-0.5.0}/README.md +5 -3
  4. {nextrec-0.4.33 → nextrec-0.5.0}/README_en.md +5 -3
  5. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/conf.py +1 -1
  6. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.utils.rst +0 -8
  7. nextrec-0.5.0/nextrec/__version__.py +1 -0
  8. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/activation.py +10 -18
  9. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/asserts.py +1 -22
  10. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/callback.py +2 -2
  11. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/features.py +6 -37
  12. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/heads.py +13 -1
  13. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/layers.py +33 -123
  14. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/loggers.py +3 -2
  15. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/metrics.py +85 -4
  16. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/model.py +518 -7
  17. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/summary.py +88 -42
  18. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/cli.py +117 -30
  19. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/data_processing.py +8 -13
  20. nextrec-0.5.0/nextrec/data/preprocessor.py +1196 -0
  21. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/grad_norm.py +78 -76
  22. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/ple.py +1 -0
  23. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/share_bottom.py +1 -0
  24. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/afm.py +4 -9
  25. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/dien.py +7 -8
  26. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/ffm.py +2 -2
  27. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/sdm.py +1 -2
  28. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/sequential/hstu.py +0 -2
  29. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/base.py +1 -1
  30. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/__init__.py +2 -1
  31. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/config.py +1 -1
  32. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/console.py +1 -1
  33. nextrec-0.5.0/nextrec/utils/onnx_utils.py +252 -0
  34. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/torch_utils.py +63 -56
  35. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/types.py +43 -0
  36. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/NextRec-CLI.md +0 -2
  37. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -2
  38. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/predict_config.yaml +6 -3
  39. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/predict_config_template.yaml +6 -2
  40. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/train_config.yaml +5 -0
  41. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/train_config_template.yaml +6 -2
  42. {nextrec-0.4.33 → nextrec-0.5.0}/pyproject.toml +5 -1
  43. {nextrec-0.4.33 → nextrec-0.5.0}/requirements.txt +4 -0
  44. {nextrec-0.4.33 → nextrec-0.5.0}/test/run_tests.py +5 -1
  45. nextrec-0.5.0/test/test_onnx_models.py +620 -0
  46. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_preprocessor.py +8 -6
  47. nextrec-0.5.0/tutorials/distributed/example_distributed_training.py +234 -0
  48. nextrec-0.5.0/tutorials/distributed/example_distributed_training_large_dataset.py +251 -0
  49. nextrec-0.5.0/tutorials/example_match.py +261 -0
  50. nextrec-0.5.0/tutorials/example_multitask.py +221 -0
  51. nextrec-0.5.0/tutorials/example_onnx.py +300 -0
  52. nextrec-0.5.0/tutorials/example_ranking_din.py +217 -0
  53. nextrec-0.5.0/tutorials/example_tree.py +205 -0
  54. nextrec-0.5.0/tutorials/movielen_match_dssm.py +270 -0
  55. nextrec-0.5.0/tutorials/movielen_ranking_deepfm.py +183 -0
  56. nextrec-0.5.0/tutorials/run_all_match_models.py +303 -0
  57. nextrec-0.5.0/tutorials/run_all_multitask_models.py +396 -0
  58. nextrec-0.5.0/tutorials/run_all_ranking_models.py +388 -0
  59. nextrec-0.4.33/nextrec/__version__.py +0 -1
  60. nextrec-0.4.33/nextrec/data/preprocessor.py +0 -1591
  61. nextrec-0.4.33/nextrec/models/multi_task/[pre]star.py +0 -192
  62. nextrec-0.4.33/nextrec/models/representation/autorec.py +0 -0
  63. nextrec-0.4.33/nextrec/models/representation/bpr.py +0 -0
  64. nextrec-0.4.33/nextrec/models/representation/cl4srec.py +0 -0
  65. nextrec-0.4.33/nextrec/models/representation/lightgcn.py +0 -0
  66. nextrec-0.4.33/nextrec/models/representation/mf.py +0 -0
  67. nextrec-0.4.33/nextrec/models/representation/s3rec.py +0 -0
  68. nextrec-0.4.33/nextrec/models/sequential/sasrec.py +0 -0
  69. nextrec-0.4.33/nextrec/utils/feature.py +0 -29
  70. nextrec-0.4.33/tutorials/distributed/example_distributed_training.py +0 -158
  71. nextrec-0.4.33/tutorials/distributed/example_distributed_training_large_dataset.py +0 -158
  72. nextrec-0.4.33/tutorials/example_match.py +0 -164
  73. nextrec-0.4.33/tutorials/example_multitask.py +0 -122
  74. nextrec-0.4.33/tutorials/example_ranking_din.py +0 -125
  75. nextrec-0.4.33/tutorials/example_tree.py +0 -97
  76. nextrec-0.4.33/tutorials/movielen_match_dssm.py +0 -155
  77. nextrec-0.4.33/tutorials/movielen_ranking_deepfm.py +0 -73
  78. nextrec-0.4.33/tutorials/run_all_match_models.py +0 -210
  79. nextrec-0.4.33/tutorials/run_all_multitask_models.py +0 -285
  80. nextrec-0.4.33/tutorials/run_all_ranking_models.py +0 -264
  81. {nextrec-0.4.33 → nextrec-0.5.0}/.github/workflows/publish.yml +0 -0
  82. {nextrec-0.4.33 → nextrec-0.5.0}/.github/workflows/tests.yml +0 -0
  83. {nextrec-0.4.33 → nextrec-0.5.0}/.readthedocs.yaml +0 -0
  84. {nextrec-0.4.33 → nextrec-0.5.0}/CODE_OF_CONDUCT.md +0 -0
  85. {nextrec-0.4.33 → nextrec-0.5.0}/CONTRIBUTING.md +0 -0
  86. {nextrec-0.4.33 → nextrec-0.5.0}/LICENSE +0 -0
  87. {nextrec-0.4.33 → nextrec-0.5.0}/MANIFEST.in +0 -0
  88. {nextrec-0.4.33 → nextrec-0.5.0}/assets/Feature Configuration.png +0 -0
  89. {nextrec-0.4.33 → nextrec-0.5.0}/assets/Model Parameters.png +0 -0
  90. {nextrec-0.4.33 → nextrec-0.5.0}/assets/Training Configuration.png +0 -0
  91. {nextrec-0.4.33 → nextrec-0.5.0}/assets/Training logs.png +0 -0
  92. {nextrec-0.4.33 → nextrec-0.5.0}/assets/logo.png +0 -0
  93. {nextrec-0.4.33 → nextrec-0.5.0}/assets/mmoe_tutorial.png +0 -0
  94. {nextrec-0.4.33 → nextrec-0.5.0}/assets/nextrec_diagram.png +0 -0
  95. {nextrec-0.4.33 → nextrec-0.5.0}/assets/test data.png +0 -0
  96. {nextrec-0.4.33 → nextrec-0.5.0}/dataset/ctcvr_task.csv +0 -0
  97. {nextrec-0.4.33 → nextrec-0.5.0}/dataset/ecommerce_task.csv +0 -0
  98. {nextrec-0.4.33 → nextrec-0.5.0}/dataset/match_task.csv +0 -0
  99. {nextrec-0.4.33 → nextrec-0.5.0}/dataset/movielens_100k.csv +0 -0
  100. {nextrec-0.4.33 → nextrec-0.5.0}/dataset/multitask_task.csv +0 -0
  101. {nextrec-0.4.33 → nextrec-0.5.0}/dataset/ranking_task.csv +0 -0
  102. {nextrec-0.4.33 → nextrec-0.5.0}/docs/en/Getting started guide.md +0 -0
  103. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/Makefile +0 -0
  104. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/index.md +0 -0
  105. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/make.bat +0 -0
  106. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/modules.rst +0 -0
  107. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.basic.rst +0 -0
  108. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.data.rst +0 -0
  109. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.loss.rst +0 -0
  110. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.rst +0 -0
  111. {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/requirements.txt +0 -0
  112. {nextrec-0.4.33 → nextrec-0.5.0}/docs/zh/快速上手.md +0 -0
  113. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/__init__.py +0 -0
  114. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/__init__.py +0 -0
  115. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/session.py +0 -0
  116. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/__init__.py +0 -0
  117. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/batch_utils.py +0 -0
  118. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/data_utils.py +0 -0
  119. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/dataloader.py +0 -0
  120. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/__init__.py +0 -0
  121. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/listwise.py +0 -0
  122. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/pairwise.py +0 -0
  123. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/pointwise.py +0 -0
  124. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/generative/__init__.py +0 -0
  125. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/generative/tiger.py +0 -0
  126. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/[pre]aitm.py +0 -0
  127. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/[pre]snr_trans.py +0 -0
  128. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/__init__.py +0 -0
  129. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/apg.py +0 -0
  130. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/cross_stitch.py +0 -0
  131. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/escm.py +0 -0
  132. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/esmm.py +0 -0
  133. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/hmoe.py +0 -0
  134. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/mmoe.py +0 -0
  135. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/pepnet.py +0 -0
  136. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/poso.py +0 -0
  137. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/__init__.py +0 -0
  138. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/autoint.py +0 -0
  139. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/dcn.py +0 -0
  140. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/dcn_v2.py +0 -0
  141. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/deepfm.py +0 -0
  142. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/din.py +0 -0
  143. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/eulernet.py +0 -0
  144. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/fibinet.py +0 -0
  145. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/fm.py +0 -0
  146. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/lr.py +0 -0
  147. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/masknet.py +0 -0
  148. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/pnn.py +0 -0
  149. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/widedeep.py +0 -0
  150. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/xdeepfm.py +0 -0
  151. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/representation/__init__.py +0 -0
  152. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/representation/rqvae.py +0 -0
  153. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/__init__.py +0 -0
  154. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/dssm.py +0 -0
  155. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/dssm_v2.py +0 -0
  156. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/mind.py +0 -0
  157. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/youtube_dnn.py +0 -0
  158. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/__init__.py +0 -0
  159. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/catboost.py +0 -0
  160. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/lightgbm.py +0 -0
  161. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/xgboost.py +0 -0
  162. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/data.py +0 -0
  163. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/embedding.py +0 -0
  164. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/loss.py +0 -0
  165. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/model.py +0 -0
  166. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/feature_config.yaml +0 -0
  167. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  168. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/apg.yaml +0 -0
  169. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  170. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/cross_stitch.yaml +0 -0
  171. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  172. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  173. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  174. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/escm.yaml +0 -0
  175. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  176. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  177. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  178. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/hmoe.yaml +0 -0
  179. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  180. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  181. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/pepnet.yaml +0 -0
  182. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  183. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  184. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  185. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  186. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  187. {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  188. {nextrec-0.4.33 → nextrec-0.5.0}/pytest.ini +0 -0
  189. {nextrec-0.4.33 → nextrec-0.5.0}/scripts/format_code.py +0 -0
  190. {nextrec-0.4.33 → nextrec-0.5.0}/test/__init__.py +0 -0
  191. {nextrec-0.4.33 → nextrec-0.5.0}/test/conftest.py +0 -0
  192. {nextrec-0.4.33 → nextrec-0.5.0}/test/helpers.py +0 -0
  193. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_base_model_regularization.py +0 -0
  194. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_generative_models.py +0 -0
  195. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_layers.py +0 -0
  196. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_losses.py +0 -0
  197. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_match_models.py +0 -0
  198. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_multitask_models.py +0 -0
  199. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_ranking_models.py +0 -0
  200. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_utils_console.py +0 -0
  201. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_utils_data.py +0 -0
  202. {nextrec-0.4.33 → nextrec-0.5.0}/test/test_utils_embedding.py +0 -0
  203. {nextrec-0.4.33 → nextrec-0.5.0}/test_requirements.txt +0 -0
  204. {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  205. {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  206. {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  207. {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  208. {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  209. {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0
@@ -128,3 +128,6 @@ pypirc.template
 
  # Sphinx build
  docs/rtd/_build/
+
+ *.onnx
+ artifacts/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.4.33
+ Version: 0.5.0
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -24,10 +24,14 @@ Requires-Dist: numpy<2.0,>=1.21; sys_platform == 'linux' and python_version < '3
  Requires-Dist: numpy<3.0,>=1.26; sys_platform == 'linux' and python_version >= '3.12'
  Requires-Dist: numpy>=1.23.0; sys_platform == 'win32'
  Requires-Dist: numpy>=1.24.0; sys_platform == 'darwin'
+ Requires-Dist: onnx>=1.16.0
+ Requires-Dist: onnxruntime>=1.18.0
+ Requires-Dist: onnxscript>=0.1.1
  Requires-Dist: pandas<2.0,>=1.5; sys_platform == 'linux' and python_version < '3.12'
  Requires-Dist: pandas<2.3.0,>=2.1.0; sys_platform == 'win32'
  Requires-Dist: pandas>=2.0.0; sys_platform == 'darwin'
  Requires-Dist: pandas>=2.1.0; sys_platform == 'linux' and python_version >= '3.12'
+ Requires-Dist: polars>=0.20.0
  Requires-Dist: pyarrow<13.0.0,>=10.0.0; sys_platform == 'linux' and python_version < '3.12'
  Requires-Dist: pyarrow<15.0.0,>=12.0.0; sys_platform == 'win32'
  Requires-Dist: pyarrow>=12.0.0; sys_platform == 'darwin'
@@ -69,7 +73,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.33-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.5.0-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  中文文档 | [English Version](README_en.md)
@@ -102,6 +106,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
  - **高效训练与评估**:内置多种优化器、学习率调度、早停、模型检查点与详细的日志管理,开箱即用。

  ## NextRec近期进展
+ - **28/01/2026** 在v0.4.39中加入了对onnx导出和加载的支持,并大大加速了数据预处理速度(最高9x加速)
  - **01/01/2026** 新年好,在v0.4.27中加入了多个多目标模型的支持:[APG](nextrec/models/multi_task/apg.py), [ESCM](nextrec/models/multi_task/escm.py), [HMoE](nextrec/models/multi_task/hmoe.py), [Cross Stitch](nextrec/models/multi_task/cross_stitch.py)
  - **28/12/2025** 在v0.4.21中加入了对SwanLab和Wandb的支持,通过model的`fit`方法进行配置:`use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
  - **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
@@ -136,6 +141,7 @@ pip install nextrec # or pip install -e .
  - [example_multitask.py](/tutorials/example_multitask.py) - 电商数据集上的ESMM多任务学习训练示例
  - [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) - 基于movielen 100k数据集训练的 DSSM 召回模型示例

+ - [example_onnx.py](/tutorials/example_onnx.py) - 使用NextRec训练和导出onnx模型
  - [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) - 使用NextRec进行单机多卡训练的代码示例

  - [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) - 快速校验所有排序模型的可用性
@@ -254,11 +260,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。

- > 截止当前版本0.4.33,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.5.0,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.33,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.5.0,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|
@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.33-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.5.0-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  中文文档 | [English Version](README_en.md)
@@ -41,6 +41,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
  - **高效训练与评估**:内置多种优化器、学习率调度、早停、模型检查点与详细的日志管理,开箱即用。

  ## NextRec近期进展
+ - **28/01/2026** 在v0.4.39中加入了对onnx导出和加载的支持,并大大加速了数据预处理速度(最高9x加速)
  - **01/01/2026** 新年好,在v0.4.27中加入了多个多目标模型的支持:[APG](nextrec/models/multi_task/apg.py), [ESCM](nextrec/models/multi_task/escm.py), [HMoE](nextrec/models/multi_task/hmoe.py), [Cross Stitch](nextrec/models/multi_task/cross_stitch.py)
  - **28/12/2025** 在v0.4.21中加入了对SwanLab和Wandb的支持,通过model的`fit`方法进行配置:`use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
  - **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
@@ -75,6 +76,7 @@ pip install nextrec # or pip install -e .
  - [example_multitask.py](/tutorials/example_multitask.py) - 电商数据集上的ESMM多任务学习训练示例
  - [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) - 基于movielen 100k数据集训练的 DSSM 召回模型示例

+ - [example_onnx.py](/tutorials/example_onnx.py) - 使用NextRec训练和导出onnx模型
  - [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) - 使用NextRec进行单机多卡训练的代码示例

  - [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) - 快速校验所有排序模型的可用性
@@ -193,11 +195,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。

- > 截止当前版本0.4.33,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.5.0,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.33,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.5.0,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|
@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.33-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.5.0-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  English | [中文文档](README.md)
@@ -44,6 +44,7 @@ NextRec is a modern recommendation framework built on PyTorch, delivering a unif

  ## NextRec Progress

+ - **28/01/2026** Added support for ONNX export and loading in v0.4.39, and significantly accelerated data preprocessing speed (up to 9x speedup)
  - **01/01/2026** Happy New Year! In v0.4.27, added support for multiple multi-task models: [APG](/nextrec/models/multi_task/apg.py), [ESCM](/nextrec/models/multi_task/escm.py), [HMoE](/nextrec/models/multi_task/hmoe.py), [Cross Stitch](/nextrec/models/multi_task/cross_stitch.py)
  - **28/12/2025** Added support for SwanLab and Weights & Biases in v0.4.21, configurable via the model `fit` method: `use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
  - **21/12/2025** Added support for [GradNorm](/nextrec/loss/grad_norm.py) in v0.4.16, configurable via `loss_weight='grad_norm'` in the compile method
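The progress entry above advertises ONNX export and loading (see `nextrec/utils/onnx_utils.py` and `tutorials/example_onnx.py` in the file list). NextRec's own export API is not visible in this diff, so the sketch below shows only the generic round trip with `torch.onnx.export` and `onnxruntime` (both now pinned in the package requirements); the `TinyRanker` module and the file name are placeholders, not NextRec code.

```python
# Illustrative ONNX round trip only; NextRec's real helpers live in
# nextrec/utils/onnx_utils.py and tutorials/example_onnx.py (not shown here).
import numpy as np
import torch
import torch.nn as nn


class TinyRanker(nn.Module):
    """Stand-in for a trained ranking model (hypothetical)."""

    def __init__(self, in_dim: int = 8):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(in_dim, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, dense: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.mlp(dense))


model = TinyRanker().eval()
example = torch.randn(4, 8)

# Export with a dynamic batch axis so serving batch size is not fixed.
torch.onnx.export(
    model,
    (example,),
    "ranker.onnx",
    input_names=["dense"],
    output_names=["score"],
    dynamic_axes={"dense": {0: "batch"}, "score": {0: "batch"}},
)

# Load and score with onnxruntime (>=1.18 per the new requirements).
import onnxruntime as ort

sess = ort.InferenceSession("ranker.onnx", providers=["CPUExecutionProvider"])
scores = sess.run(["score"], {"dense": example.numpy().astype(np.float32)})[0]
print(scores.shape)  # (4, 1)
```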
@@ -79,6 +80,7 @@ See `tutorials/` for examples covering ranking, retrieval, multi-task learning,
  - [example_multitask.py](/tutorials/example_multitask.py) — ESMM multi-task learning training on e-commerce dataset
  - [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) — DSSM retrieval model training on MovieLens 100k dataset

+ - [example_onnx.py](/tutorials/example_onnx.py) — Train and export models to ONNX format with NextRec
  - [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) — Single-machine multi-GPU training with NextRec

  - [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) — Quickly validate availability of all ranking models
@@ -196,11 +198,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.

- > As of version 0.4.33, NextRec CLI supports single-machine training; distributed training features are currently under development.
+ > As of version 0.5.0, NextRec CLI supports single-machine training; distributed training features are currently under development.

  ## Platform Compatibility

- The current version is 0.4.33. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
+ The current version is 0.5.0. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:

  | Platform | Configuration |
  |----------|---------------|
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
  project = "NextRec"
  copyright = "2026, Yang Zhou"
  author = "Yang Zhou"
- release = "0.4.33"
+ release = "0.5.0"

  extensions = [
  "myst_parser",
@@ -28,14 +28,6 @@ nextrec.utils.data module
  :undoc-members:
  :show-inheritance:

- nextrec.utils.feature module
- ----------------------------
-
- .. automodule:: nextrec.utils.feature
- :members:
- :undoc-members:
- :show-inheritance:
-
  nextrec.utils.model module
  --------------------------

@@ -0,0 +1 @@
+ __version__ = "0.5.0"
@@ -1,8 +1,8 @@
  """
- Activation function definitions for NextRec models.
+ Activation function definitions.

  Date: create on 27/10/2025
- Checkpoint: edit on 28/12/2025
+ Checkpoint: edit on 20/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -22,26 +22,18 @@ class Dice(nn.Module):
  where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
  """

- def __init__(self, emb_size: int, epsilon: float = 1e-9):
+ def __init__(self, emb_size: int, epsilon: float = 1e-3):
  super(Dice, self).__init__()
- self.epsilon = epsilon
  self.alpha = nn.Parameter(torch.zeros(emb_size))
- self.bn = nn.BatchNorm1d(emb_size)
+ self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)

  def forward(self, x):
- # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
- original_shape = x.shape
-
- if x.dim() == 3:
- # For 3D input (batch_size, seq_len, emb_size), reshape to 2D
- batch_size, seq_len, emb_size = x.shape
- x = x.view(-1, emb_size)
- x_norm = self.bn(x)
- p = torch.sigmoid(x_norm)
- output = p * x + (1 - p) * self.alpha * x
- if len(original_shape) == 3:
- output = output.view(original_shape)
- return output
+ # keep original shape for reshaping back after batch norm
+ orig_shape = x.shape # x: [N, L, emb_size] or [N, emb_size]
+ x2 = x.reshape(-1, orig_shape[-1]) # x2:[N*L, emb_size] or [N, emb_size]
+ x_norm = self.bn(x2)
+ p = torch.sigmoid(x_norm).reshape(orig_shape)
+ return x * (self.alpha + (1 - self.alpha) * p)


  def activation_layer(
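The refactored `Dice.forward` above flattens any leading dimensions into the batch so a single `BatchNorm1d` serves both `[N, E]` and `[N, L, E]` inputs. The snippet below is a standalone re-implementation mirroring that hunk (same math, not imported from nextrec), shown only to make the reshape behaviour concrete.

```python
# Standalone mirror of the refactored Dice above; 2D and 3D inputs share one BatchNorm1d.
import torch
import torch.nn as nn


class Dice(nn.Module):
    def __init__(self, emb_size: int, epsilon: float = 1e-3):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(emb_size))
        self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape                # [N, E] or [N, L, E]
        x2 = x.reshape(-1, orig_shape[-1])  # flatten leading dims for BatchNorm1d
        p = torch.sigmoid(self.bn(x2)).reshape(orig_shape)
        return x * (self.alpha + (1 - self.alpha) * p)


dice = Dice(emb_size=16)
print(dice(torch.randn(32, 16)).shape)      # torch.Size([32, 16])
print(dice(torch.randn(32, 10, 16)).shape)  # torch.Size([32, 10, 16])
```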
@@ -8,7 +8,7 @@ Author: Yang Zhou, zyaztec@gmail.com

  from __future__ import annotations

- from nextrec.utils.types import TaskTypeName, TrainingModeName
+ from nextrec.utils.types import TaskTypeName


  def assert_task(
@@ -49,24 +49,3 @@ def assert_task(
  raise ValueError(
  f"{model_name} requires task length {nums_task}, got {len(task)}."
  )
-
-
- def assert_training_mode(
- training_mode: TrainingModeName | list[TrainingModeName],
- nums_task: int,
- *,
- model_name: str,
- ) -> None:
- valid_modes = {"pointwise", "pairwise", "listwise"}
- if not isinstance(training_mode, list):
- raise TypeError(
- f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
- )
- if len(training_mode) != nums_task:
- raise ValueError(
- f"[{model_name}-init Error] training_mode list length must match number of tasks."
- )
- if any(mode not in valid_modes for mode in training_mode):
- raise ValueError(
- f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
- )
@@ -2,7 +2,7 @@
  Callback System for Training Process

  Date: create on 27/10/2025
- Checkpoint: edit on 27/12/2025
+ Checkpoint: edit on 21/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -69,7 +69,7 @@ class Callback:

  class CallbackList:
  """
- Generates a list of callbacks
+ Generates a list of callbacks, used to manage and invoke multiple callbacks during training.
  """

  def __init__(self, callbacks: Optional[list[Callback]] = None):
@@ -8,10 +8,9 @@ Author: Yang Zhou, zyaztec@gmail.com

  import torch

- from typing import Literal
-
  from nextrec.utils.embedding import get_auto_embedding_dim
- from nextrec.utils.feature import to_list
+ from nextrec.utils.torch_utils import to_list
+ from nextrec.utils.types import EmbeddingInitType, SequenceCombinerType


  class BaseFeature:
@@ -29,15 +28,7 @@ class EmbeddingFeature(BaseFeature):
  embedding_name: str = "",
  embedding_dim: int | None = None,
  padding_idx: int = 0,
- init_type: Literal[
- "normal",
- "uniform",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- "orthogonal",
- ] = "normal",
+ init_type: EmbeddingInitType = "normal",
  init_params: dict | None = None,
  l1_reg: float = 0.0,
  l2_reg: float = 0.0,
@@ -73,23 +64,9 @@ class SequenceFeature(EmbeddingFeature):
  max_len: int = 50,
  embedding_name: str = "",
  embedding_dim: int | None = None,
- combiner: Literal[
- "mean",
- "sum",
- "concat",
- "dot_attention",
- "self_attention",
- ] = "mean",
+ combiner: SequenceCombinerType = "mean",
  padding_idx: int = 0,
- init_type: Literal[
- "normal",
- "uniform",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- "orthogonal",
- ] = "normal",
+ init_type: EmbeddingInitType = "normal",
  init_params: dict | None = None,
  l1_reg: float = 0.0,
  l2_reg: float = 0.0,
@@ -143,15 +120,7 @@ class SparseFeature(EmbeddingFeature):
  embedding_name: str = "",
  embedding_dim: int | None = None,
  padding_idx: int = 0,
- init_type: Literal[
- "normal",
- "uniform",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- "orthogonal",
- ] = "normal",
+ init_type: EmbeddingInitType = "normal",
  init_params: dict | None = None,
  l1_reg: float = 0.0,
  l2_reg: float = 0.0,
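These hunks replace the repeated inline `Literal[...]` unions with named aliases imported from `nextrec.utils.types` (which grows by 43 lines in this release). The aliases presumably look roughly like the sketch below; the member lists are copied from the removed `Literal` blocks, but the actual definitions in `nextrec/utils/types.py` are not shown in this diff.

```python
# Sketch of the kind of aliases nextrec.utils.types presumably exposes (assumed, not shown above).
from typing import Literal

EmbeddingInitType = Literal[
    "normal", "uniform", "xavier_uniform", "xavier_normal",
    "kaiming_uniform", "kaiming_normal", "orthogonal",
]
SequenceCombinerType = Literal[
    "mean", "sum", "concat", "dot_attention", "self_attention",
]


def make_feature(init_type: EmbeddingInitType = "normal",
                 combiner: SequenceCombinerType = "mean") -> tuple[str, str]:
    # Type checkers now flag invalid strings at every call site that reuses the alias.
    return init_type, combiner
```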
@@ -2,7 +2,7 @@
  Task head implementations for NextRec models.

  Date: create on 23/12/2025
- Checkpoint: edit on 27/12/2025
+ Checkpoint: edit on 22/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -24,6 +24,12 @@ class TaskHead(nn.Module):

  This wraps PredictionLayer so models can depend on a "Head" abstraction
  without changing their existing forward signatures.
+
+ Args:
+ task_type: The type of task(s) this head is responsible for.
+ task_dims: The dimensionality of each task's output.
+ use_bias: Whether to include a bias term in the prediction layer.
+ return_logits: Whether to return raw logits or apply activation.
  """

  def __init__(
@@ -56,6 +62,12 @@ class RetrievalHead(nn.Module):

  It computes similarity for pointwise training/inference, and returns
  raw embeddings for in-batch negative sampling in pairwise/listwise modes.
+
+ Args:
+ similarity_metric: The metric used to compute similarity between embeddings.
+ temperature: Scaling factor for similarity scores.
+ training_mode: The training mode, which can be pointwise, pairwise, or listwise.
+ apply_sigmoid: Whether to apply sigmoid activation to the similarity scores in pointwise mode.
  """

  def __init__(
@@ -2,7 +2,7 @@
  Layer implementations used across NextRec.

  Date: create on 27/10/2025
- Checkpoint: edit on 27/12/2025
+ Checkpoint: edit on 25/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -20,15 +20,13 @@ import torch.nn.functional as F
  from nextrec.basic.activation import activation_layer
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.utils.torch_utils import get_initializer
- from nextrec.utils.types import ActivationName
+ from nextrec.utils.types import ActivationName, TaskTypeName


  class PredictionLayer(nn.Module):
  def __init__(
  self,
- task_type: (
- Literal["binary", "regression"] | list[Literal["binary", "regression"]]
- ) = "binary",
+ task_type: TaskTypeName | list[TaskTypeName] = "binary",
  task_dims: int | list[int] | None = None,
  use_bias: bool = True,
  return_logits: bool = False,
@@ -81,10 +79,12 @@ class PredictionLayer(nn.Module):
  def forward(self, x: torch.Tensor) -> torch.Tensor:
  if x.dim() == 1:
  x = x.unsqueeze(0) # (1 * total_dim)
- if x.shape[-1] != self.total_dim:
- raise ValueError(
- f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
- )
+ if not torch.onnx.is_in_onnx_export():
+ if x.shape[-1] != self.total_dim:
+ raise ValueError(
+ f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
+ )
+
  logits = x if self.bias is None else x + self.bias
  outputs = []
  for task_type, (start, end) in zip(self.task_types, self.task_slices):
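The change above (and the analogous `EmbeddingLayer` hunk further down) wraps eager-mode shape validation in `torch.onnx.is_in_onnx_export()` so the Python-side `ValueError` never enters the traced graph during export. A minimal sketch of the same pattern, using a made-up `CheckedLinear` module rather than NextRec code:

```python
# Skip Python-side shape validation while torch.onnx.export is tracing,
# keep it for eager execution. torch.onnx.is_in_onnx_export() is a real
# torch API; the module below is purely illustrative.
import torch
import torch.nn as nn


class CheckedLinear(nn.Module):
    def __init__(self, in_dim: int):
        super().__init__()
        self.in_dim = in_dim
        self.proj = nn.Linear(in_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.onnx.is_in_onnx_export():
            # Eager-mode sanity check; during export this branch is skipped,
            # so the exception never becomes part of the exported graph.
            if x.shape[-1] != self.in_dim:
                raise ValueError(f"expected last dim {self.in_dim}, got {x.shape[-1]}")
        return self.proj(x)


m = CheckedLinear(8).eval()
torch.onnx.export(m, (torch.randn(2, 8),), "checked_linear.onnx",
                  input_names=["x"], output_names=["y"],
                  dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}})
```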
@@ -92,10 +92,9 @@
  if self.return_logits:
  outputs.append(task_logits)
  continue
- task = task_type.lower()
- if task == "binary":
+ if task_type == "binary":
  outputs.append(torch.sigmoid(task_logits))
- elif task == "regression":
+ elif task_type == "regression":
  outputs.append(task_logits)
  else:
  raise ValueError(
@@ -219,7 +218,7 @@ class EmbeddingLayer(nn.Module):

  elif isinstance(feature, SequenceFeature):
  seq_input = x[feature.name].long()
- if feature.max_len is not None and seq_input.size(1) > feature.max_len:
+ if feature.max_len is not None:
  seq_input = seq_input[:, -feature.max_len :]

  embed = self.embed_dict[feature.embedding_name]
@@ -282,10 +281,11 @@ class EmbeddingLayer(nn.Module):
  value = value.view(value.size(0), -1) # [B, input_dim]
  input_dim = feature.input_dim
  assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
- if value.shape[1] != assert_input_dim:
- raise ValueError(
- f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
- )
+ if not torch.onnx.is_in_onnx_export():
+ if value.shape[1] != assert_input_dim:
+ raise ValueError(
+ f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
+ )
  if not feature.use_projection:
  return value
  dense_layer = self.dense_transforms[feature.name]
@@ -331,29 +331,10 @@ class InputMask(nn.Module):
  feature: SequenceFeature,
  seq_tensor: torch.Tensor | None = None,
  ):
- if seq_tensor is not None:
- values = seq_tensor
- else:
- values = x[feature.name]
- values = values.long()
+ values = seq_tensor if seq_tensor is not None else x[feature.name]
+ values = values.long().view(values.size(0), -1)
  padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
- mask = values != padding_idx
-
- if mask.dim() == 1:
- # [B] -> [B, 1, 1]
- mask = mask.unsqueeze(1).unsqueeze(2)
- elif mask.dim() == 2:
- # [B, L] -> [B, 1, L]
- mask = mask.unsqueeze(1)
- elif mask.dim() == 3:
- # [B, 1, L]
- # [B, L, 1] -> [B, L] -> [B, 1, L]
- if mask.size(1) != 1 and mask.size(2) == 1:
- mask = mask.squeeze(-1).unsqueeze(1)
- else:
- raise ValueError(
- f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
- )
+ mask = (values != padding_idx).unsqueeze(1)
  return mask.float()


@@ -897,30 +878,7 @@ class AttentionPoolingLayer(nn.Module):
  self,
  embedding_dim: int,
  hidden_units: list = [80, 40],
- activation: Literal[
- "dice",
- "relu",
- "relu6",
- "elu",
- "selu",
- "leaky_relu",
- "prelu",
- "gelu",
- "sigmoid",
- "tanh",
- "softplus",
- "softsign",
- "hardswish",
- "mish",
- "silu",
- "swish",
- "hardsigmoid",
- "tanhshrink",
- "softshrink",
- "none",
- "linear",
- "identity",
- ] = "sigmoid",
+ activation: ActivationName = "sigmoid",
  use_softmax: bool = False,
  ):
  super().__init__()
@@ -954,39 +912,22 @@ class AttentionPoolingLayer(nn.Module):
  output: [batch_size, embedding_dim] - attention pooled representation
  """
  batch_size, sequence_length, embedding_dim = keys.shape
- assert query.shape == (
- batch_size,
- embedding_dim,
- ), f"query shape {query.shape} != ({batch_size}, {embedding_dim})"
- if mask is None and keys_length is not None:
- # keys_length: (batch_size,)
- device = keys.device
- seq_range = torch.arange(sequence_length, device=device).unsqueeze(
- 0
- ) # (1, sequence_length)
- mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
- if mask is not None:
- if mask.dim() == 2:
- # (B, L)
- mask = mask.unsqueeze(-1)
- elif (
- mask.dim() == 3
- and mask.shape[1] == 1
- and mask.shape[2] == sequence_length
- ):
- # (B, 1, L) -> (B, L, 1)
- mask = mask.transpose(1, 2)
- elif (
- mask.dim() == 3
- and mask.shape[1] == sequence_length
- and mask.shape[2] == 1
- ):
- pass
+ if mask is None:
+ if keys_length is None:
+ mask = torch.ones(
+ (batch_size, sequence_length), device=keys.device, dtype=keys.dtype
+ )
  else:
+ device = keys.device
+ seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
+ mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
+ else:
+ mask = mask.to(keys.dtype).reshape(batch_size, -1)
+ if mask.shape[1] != sequence_length:
  raise ValueError(
  f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
  )
- mask = mask.to(keys.dtype)
+ mask = mask.unsqueeze(-1)
  # Expand query to (B, L, D)
  query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
  # [query, key, query-key, query*key] -> (B, L, 4D)
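The rewritten mask handling above derives a `[B, L]` validity mask from `keys_length` with a broadcasted `arange` comparison, then unsqueezes it to `[B, L, 1]`. A tiny worked example of that idiom (standalone, not NextRec code):

```python
# Worked example of the keys_length -> mask idiom used above.
import torch

keys_length = torch.tensor([3, 1, 5])   # valid prefix length per sequence
sequence_length = 5

seq_range = torch.arange(sequence_length).unsqueeze(0)    # [1, L]
mask = (seq_range < keys_length.unsqueeze(1)).float()     # [B, L] via broadcasting
print(mask)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [1., 1., 1., 1., 1.]])
mask = mask.unsqueeze(-1)                                  # [B, L, 1] for weighting
```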
@@ -1026,34 +967,3 @@ class RMSNorm(torch.nn.Module):
  variance = torch.mean(x**2, dim=-1, keepdim=True)
  x_normalized = x * torch.rsqrt(variance + self.eps)
  return self.weight * x_normalized
-
-
- class DomainBatchNorm(nn.Module):
- """Domain-specific BatchNorm (applied per-domain with a shared interface)."""
-
- def __init__(self, num_features: int, num_domains: int):
- super().__init__()
- if num_domains < 1:
- raise ValueError("num_domains must be >= 1")
- self.bns = nn.ModuleList(
- [nn.BatchNorm1d(num_features) for _ in range(num_domains)]
- )
-
- def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
- if x.dim() != 2:
- raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
- output = x.clone()
- if domain_mask.dim() == 1:
- domain_ids = domain_mask.long()
- for idx, bn in enumerate(self.bns):
- mask = domain_ids == idx
- if mask.any():
- output[mask] = bn(x[mask])
- return output
- if domain_mask.dim() != 2:
- raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
- for idx, bn in enumerate(self.bns):
- mask = domain_mask[:, idx] > 0
- if mask.any():
- output[mask] = bn(x[mask])
- return output
@@ -2,7 +2,7 @@
  NextRec Basic Loggers

  Date: create on 27/10/2025
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 22/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -99,7 +99,8 @@ def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:


  def setup_logger(session_id: str | os.PathLike | None = None):
- """Set up a logger that logs to both console and a file with ANSI formatting.
+ """
+ Set up a logger that logs to both console and a file with ANSI formatting.
  Only console output has colors; file output is stripped of ANSI codes.

  Logs are stored under ``log/<experiment_id>/logs`` by default. A stable