nextrec 0.4.22__tar.gz → 0.4.24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. {nextrec-0.4.22 → nextrec-0.4.24}/PKG-INFO +8 -6
  2. {nextrec-0.4.22 → nextrec-0.4.24}/README.md +7 -5
  3. {nextrec-0.4.22 → nextrec-0.4.24}/README_en.md +5 -5
  4. {nextrec-0.4.22 → nextrec-0.4.24}/docs/en/Getting started guide.md +2 -1
  5. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/conf.py +1 -1
  6. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/index.md +1 -0
  7. {nextrec-0.4.22 → nextrec-0.4.24}/docs/zh/快速上手.md +2 -1
  8. nextrec-0.4.24/nextrec/__version__.py +1 -0
  9. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/layers.py +96 -46
  10. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/metrics.py +128 -114
  11. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/model.py +94 -91
  12. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/summary.py +36 -2
  13. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/data/dataloader.py +2 -0
  14. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/data/preprocessor.py +137 -5
  15. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/loss/listwise.py +19 -6
  16. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/loss/pairwise.py +6 -4
  17. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/loss/pointwise.py +8 -6
  18. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/multi_task/esmm.py +5 -28
  19. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/multi_task/mmoe.py +6 -28
  20. nextrec-0.4.24/nextrec/models/multi_task/pepnet.py +335 -0
  21. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/multi_task/ple.py +21 -40
  22. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/multi_task/poso.py +17 -39
  23. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/multi_task/share_bottom.py +5 -28
  24. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/afm.py +3 -27
  25. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/autoint.py +5 -38
  26. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/dcn.py +1 -26
  27. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/dcn_v2.py +6 -34
  28. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/deepfm.py +2 -29
  29. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/dien.py +2 -28
  30. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/din.py +2 -27
  31. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/eulernet.py +3 -30
  32. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/ffm.py +0 -26
  33. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/fibinet.py +8 -32
  34. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/fm.py +0 -29
  35. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/lr.py +0 -30
  36. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/masknet.py +4 -30
  37. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/pnn.py +4 -28
  38. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/widedeep.py +0 -32
  39. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/ranking/xdeepfm.py +0 -30
  40. nextrec-0.4.24/nextrec/models/representation/mf.py +0 -0
  41. nextrec-0.4.24/nextrec/models/representation/s3rec.py +0 -0
  42. nextrec-0.4.24/nextrec/models/retrieval/__init__.py +0 -0
  43. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/retrieval/dssm.py +4 -28
  44. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/retrieval/dssm_v2.py +4 -28
  45. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/retrieval/mind.py +2 -22
  46. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/retrieval/sdm.py +4 -24
  47. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/retrieval/youtube_dnn.py +4 -25
  48. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/sequential/hstu.py +0 -18
  49. nextrec-0.4.24/nextrec/models/sequential/sasrec.py +0 -0
  50. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/model.py +91 -4
  51. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/types.py +35 -0
  52. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/NextRec-CLI.md +7 -7
  53. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/NextRec-CLI_zh.md +7 -7
  54. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/dcn.yaml +1 -1
  55. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/deepfm.yaml +1 -1
  56. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/din.yaml +1 -1
  57. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/esmm.yaml +2 -2
  58. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/fibinet.yaml +1 -1
  59. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/masknet.yaml +1 -1
  60. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/mmoe.yaml +3 -3
  61. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/ple.yaml +4 -4
  62. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/pnn.yaml +1 -1
  63. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/poso.yaml +2 -3
  64. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/share_bottom.yaml +3 -3
  65. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/widedeep.yaml +1 -1
  66. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/xdeepfm.yaml +1 -1
  67. {nextrec-0.4.22 → nextrec-0.4.24}/pyproject.toml +1 -1
  68. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_layers.py +9 -9
  69. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_multitask_models.py +172 -47
  70. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_ranking_models.py +33 -28
  71. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/example_multitask.py +1 -8
  72. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/example_ranking_din.py +3 -5
  73. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/notebooks/en/Hands on nextrec.ipynb +1 -1
  74. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/notebooks/zh/快速入门nextrec.ipynb +1 -1
  75. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/run_all_multitask_models.py +15 -1
  76. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/run_all_ranking_models.py +1 -1
  77. nextrec-0.4.22/nextrec/__version__.py +0 -1
  78. {nextrec-0.4.22 → nextrec-0.4.24}/.github/workflows/publish.yml +0 -0
  79. {nextrec-0.4.22 → nextrec-0.4.24}/.github/workflows/tests.yml +0 -0
  80. {nextrec-0.4.22 → nextrec-0.4.24}/.gitignore +0 -0
  81. {nextrec-0.4.22 → nextrec-0.4.24}/.readthedocs.yaml +0 -0
  82. {nextrec-0.4.22 → nextrec-0.4.24}/CODE_OF_CONDUCT.md +0 -0
  83. {nextrec-0.4.22 → nextrec-0.4.24}/CONTRIBUTING.md +0 -0
  84. {nextrec-0.4.22 → nextrec-0.4.24}/LICENSE +0 -0
  85. {nextrec-0.4.22 → nextrec-0.4.24}/MANIFEST.in +0 -0
  86. {nextrec-0.4.22 → nextrec-0.4.24}/assets/Feature Configuration.png +0 -0
  87. {nextrec-0.4.22 → nextrec-0.4.24}/assets/Model Parameters.png +0 -0
  88. {nextrec-0.4.22 → nextrec-0.4.24}/assets/Training Configuration.png +0 -0
  89. {nextrec-0.4.22 → nextrec-0.4.24}/assets/Training logs.png +0 -0
  90. {nextrec-0.4.22 → nextrec-0.4.24}/assets/logo.png +0 -0
  91. {nextrec-0.4.22 → nextrec-0.4.24}/assets/mmoe_tutorial.png +0 -0
  92. {nextrec-0.4.22 → nextrec-0.4.24}/assets/nextrec_diagram.png +0 -0
  93. {nextrec-0.4.22 → nextrec-0.4.24}/assets/test data.png +0 -0
  94. {nextrec-0.4.22 → nextrec-0.4.24}/dataset/ctcvr_task.csv +0 -0
  95. {nextrec-0.4.22 → nextrec-0.4.24}/dataset/ecommerce_task.csv +0 -0
  96. {nextrec-0.4.22 → nextrec-0.4.24}/dataset/match_task.csv +0 -0
  97. {nextrec-0.4.22 → nextrec-0.4.24}/dataset/movielens_100k.csv +0 -0
  98. {nextrec-0.4.22 → nextrec-0.4.24}/dataset/multitask_task.csv +0 -0
  99. {nextrec-0.4.22 → nextrec-0.4.24}/dataset/ranking_task.csv +0 -0
  100. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/Makefile +0 -0
  101. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/make.bat +0 -0
  102. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/modules.rst +0 -0
  103. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/nextrec.basic.rst +0 -0
  104. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/nextrec.data.rst +0 -0
  105. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/nextrec.loss.rst +0 -0
  106. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/nextrec.rst +0 -0
  107. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/nextrec.utils.rst +0 -0
  108. {nextrec-0.4.22 → nextrec-0.4.24}/docs/rtd/requirements.txt +0 -0
  109. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/__init__.py +0 -0
  110. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/__init__.py +0 -0
  111. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/activation.py +0 -0
  112. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/callback.py +0 -0
  113. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/features.py +0 -0
  114. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/heads.py +0 -0
  115. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/loggers.py +0 -0
  116. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/basic/session.py +0 -0
  117. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/cli.py +0 -0
  118. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/data/__init__.py +0 -0
  119. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/data/batch_utils.py +0 -0
  120. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/data/data_processing.py +0 -0
  121. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/data/data_utils.py +0 -0
  122. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/loss/__init__.py +0 -0
  123. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/loss/grad_norm.py +0 -0
  124. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/generative/__init__.py +0 -0
  125. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/generative/tiger.py +0 -0
  126. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/multi_task/__init__.py +0 -0
  127. /nextrec-0.4.22/nextrec/models/ranking/__init__.py → /nextrec-0.4.24/nextrec/models/multi_task/aitm.py +0 -0
  128. /nextrec-0.4.22/nextrec/models/representation/autorec.py → /nextrec-0.4.24/nextrec/models/multi_task/apg.py +0 -0
  129. /nextrec-0.4.22/nextrec/models/representation/bpr.py → /nextrec-0.4.24/nextrec/models/multi_task/cross_stitch.py +0 -0
  130. /nextrec-0.4.22/nextrec/models/representation/cl4srec.py → /nextrec-0.4.24/nextrec/models/multi_task/snr_trans.py +0 -0
  131. {nextrec-0.4.22/nextrec/models/retrieval → nextrec-0.4.24/nextrec/models/ranking}/__init__.py +0 -0
  132. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/representation/__init__.py +0 -0
  133. /nextrec-0.4.22/nextrec/models/representation/lightgcn.py → /nextrec-0.4.24/nextrec/models/representation/autorec.py +0 -0
  134. /nextrec-0.4.22/nextrec/models/representation/mf.py → /nextrec-0.4.24/nextrec/models/representation/bpr.py +0 -0
  135. /nextrec-0.4.22/nextrec/models/representation/s3rec.py → /nextrec-0.4.24/nextrec/models/representation/cl4srec.py +0 -0
  136. /nextrec-0.4.22/nextrec/models/sequential/sasrec.py → /nextrec-0.4.24/nextrec/models/representation/lightgcn.py +0 -0
  137. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/models/representation/rqvae.py +0 -0
  138. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/__init__.py +0 -0
  139. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/config.py +0 -0
  140. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/console.py +0 -0
  141. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/data.py +0 -0
  142. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/embedding.py +0 -0
  143. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/feature.py +0 -0
  144. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/loss.py +0 -0
  145. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec/utils/torch_utils.py +0 -0
  146. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/feature_config.yaml +0 -0
  147. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  148. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  149. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  150. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/predict_config.yaml +0 -0
  151. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/predict_config_template.yaml +0 -0
  152. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/train_config.yaml +0 -0
  153. {nextrec-0.4.22 → nextrec-0.4.24}/nextrec_cli_preset/train_config_template.yaml +0 -0
  154. {nextrec-0.4.22 → nextrec-0.4.24}/pytest.ini +0 -0
  155. {nextrec-0.4.22 → nextrec-0.4.24}/requirements.txt +0 -0
  156. {nextrec-0.4.22 → nextrec-0.4.24}/scripts/format_code.py +0 -0
  157. {nextrec-0.4.22 → nextrec-0.4.24}/test/__init__.py +0 -0
  158. {nextrec-0.4.22 → nextrec-0.4.24}/test/conftest.py +0 -0
  159. {nextrec-0.4.22 → nextrec-0.4.24}/test/helpers.py +0 -0
  160. {nextrec-0.4.22 → nextrec-0.4.24}/test/run_tests.py +0 -0
  161. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_base_model_regularization.py +0 -0
  162. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_generative_models.py +0 -0
  163. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_losses.py +0 -0
  164. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_match_models.py +0 -0
  165. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_preprocessor.py +0 -0
  166. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_utils_console.py +0 -0
  167. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_utils_data.py +0 -0
  168. {nextrec-0.4.22 → nextrec-0.4.24}/test/test_utils_embedding.py +0 -0
  169. {nextrec-0.4.22 → nextrec-0.4.24}/test_requirements.txt +0 -0
  170. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/distributed/example_distributed_training.py +0 -0
  171. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  172. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/example_match.py +0 -0
  173. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/movielen_match_dssm.py +0 -0
  174. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/movielen_ranking_deepfm.py +0 -0
  175. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  176. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  177. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  178. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  179. {nextrec-0.4.22 → nextrec-0.4.24}/tutorials/run_all_match_models.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.22
3
+ Version: 0.4.24
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
69
69
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
70
70
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
71
71
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
72
- ![Version](https://img.shields.io/badge/Version-0.4.22-orange.svg)
72
+ ![Version](https://img.shields.io/badge/Version-0.4.24-orange.svg)
73
73
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
74
74
 
75
75
  中文文档 | [English Version](README_en.md)
@@ -182,7 +182,7 @@ sequence_features = [
182
182
  SequenceFeature(name='sequence_1', vocab_size=int(df['sequence_1'].apply(lambda x: max(x)).max() + 1), embedding_dim=16, padding_idx=0, embedding_name='sparse_0_emb'),]
183
183
 
184
184
  mlp_params = {
185
- "dims": [256, 128, 64],
185
+ "hidden_dims": [256, 128, 64],
186
186
  "activation": "relu",
187
187
  "dropout": 0.3,
188
188
  }
@@ -191,6 +191,8 @@ model = DIN(
191
191
  dense_features=dense_features,
192
192
  sparse_features=sparse_features,
193
193
  sequence_features=sequence_features,
194
+ behavior_feature_name="sequence_0",
195
+ candidate_feature_name="item_id",
194
196
  mlp_params=mlp_params,
195
197
  attention_hidden_units=[80, 40],
196
198
  attention_activation='sigmoid',
@@ -204,7 +206,7 @@ model = DIN(
204
206
  session_id="din_tutorial", # 实验id,用于存放训练日志
205
207
  )
206
208
 
207
- # 编译模型,设置优化器和损失函数
209
+ # 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
208
210
  model.compile(
209
211
  optimizer = "adam",
210
212
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -247,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
247
249
 
248
250
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
249
251
 
250
- > 截止当前版本0.4.22,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
252
+ > 截止当前版本0.4.24,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
251
253
 
252
254
  ## 兼容平台
253
255
 
254
- 当前最新版本为0.4.22,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
256
+ 当前最新版本为0.4.24,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
255
257
 
256
258
  | 平台 | 配置 |
257
259
  |------|------|
@@ -8,7 +8,7 @@
8
8
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
11
- ![Version](https://img.shields.io/badge/Version-0.4.22-orange.svg)
11
+ ![Version](https://img.shields.io/badge/Version-0.4.24-orange.svg)
12
12
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
13
13
 
14
14
  中文文档 | [English Version](README_en.md)
@@ -121,7 +121,7 @@ sequence_features = [
121
121
  SequenceFeature(name='sequence_1', vocab_size=int(df['sequence_1'].apply(lambda x: max(x)).max() + 1), embedding_dim=16, padding_idx=0, embedding_name='sparse_0_emb'),]
122
122
 
123
123
  mlp_params = {
124
- "dims": [256, 128, 64],
124
+ "hidden_dims": [256, 128, 64],
125
125
  "activation": "relu",
126
126
  "dropout": 0.3,
127
127
  }
@@ -130,6 +130,8 @@ model = DIN(
130
130
  dense_features=dense_features,
131
131
  sparse_features=sparse_features,
132
132
  sequence_features=sequence_features,
133
+ behavior_feature_name="sequence_0",
134
+ candidate_feature_name="item_id",
133
135
  mlp_params=mlp_params,
134
136
  attention_hidden_units=[80, 40],
135
137
  attention_activation='sigmoid',
@@ -143,7 +145,7 @@ model = DIN(
143
145
  session_id="din_tutorial", # 实验id,用于存放训练日志
144
146
  )
145
147
 
146
- # 编译模型,设置优化器和损失函数
148
+ # 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
147
149
  model.compile(
148
150
  optimizer = "adam",
149
151
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -186,11 +188,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
186
188
 
187
189
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
188
190
 
189
- > 截止当前版本0.4.22,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
191
+ > 截止当前版本0.4.24,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
190
192
 
191
193
  ## 兼容平台
192
194
 
193
- 当前最新版本为0.4.22,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
195
+ 当前最新版本为0.4.24,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
194
196
 
195
197
  | 平台 | 配置 |
196
198
  |------|------|
@@ -8,7 +8,7 @@
8
8
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
11
- ![Version](https://img.shields.io/badge/Version-0.4.22-orange.svg)
11
+ ![Version](https://img.shields.io/badge/Version-0.4.24-orange.svg)
12
12
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
13
13
 
14
14
  English | [中文文档](README.md)
@@ -126,7 +126,7 @@ sequence_features = [
126
126
  SequenceFeature(name='sequence_1', vocab_size=int(df['sequence_1'].apply(lambda x: max(x)).max() + 1), embedding_dim=16, padding_idx=0, embedding_name='sparse_0_emb'),]
127
127
 
128
128
  mlp_params = {
129
- "dims": [256, 128, 64],
129
+ "hidden_dims": [256, 128, 64],
130
130
  "activation": "relu",
131
131
  "dropout": 0.3,
132
132
  }
@@ -148,7 +148,7 @@ model = DIN(
148
148
  session_id="din_tutorial", # experiment id for logs
149
149
  )
150
150
 
151
- # Compile model with optimizer and loss
151
+ # Compile model; configure optimizer/loss/scheduler via compile()
152
152
  model.compile(
153
153
  optimizer = "adam",
154
154
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -191,11 +191,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
191
191
 
192
192
  Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.
193
193
 
194
- > As of version 0.4.22, NextRec CLI supports single-machine training; distributed training features are currently under development.
194
+ > As of version 0.4.24, NextRec CLI supports single-machine training; distributed training features are currently under development.
195
195
 
196
196
  ## Platform Compatibility
197
197
 
198
- The current version is 0.4.22. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
198
+ The current version is 0.4.24. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
199
199
 
200
200
  | Platform | Configuration |
201
201
  |----------|---------------|
@@ -49,12 +49,13 @@ train_df, valid_df = train_test_split(df, test_size=0.2, random_state=2024)
49
49
  model = DeepFM(
50
50
  dense_features=dense_features,
51
51
  sparse_features=sparse_features,
52
- mlp_params={"dims": [256, 128], "activation": "relu", "dropout": 0.2},
52
+ mlp_params={"hidden_dims": [256, 128], "activation": "relu", "dropout": 0.2},
53
53
  target="label",
54
54
  device="cpu",
55
55
  session_id="movielens_deepfm", # manages logs and checkpoints
56
56
  )
57
57
 
58
+ # Optimizer/loss/scheduler are configured via compile()
58
59
  model.compile(
59
60
  optimizer="adam",
60
61
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
11
11
  project = "NextRec"
12
12
  copyright = "2025, Yang Zhou"
13
13
  author = "Yang Zhou"
14
- release = "0.4.22"
14
+ release = "0.4.24"
15
15
 
16
16
  extensions = [
17
17
  "myst_parser",
@@ -57,6 +57,7 @@ model = DeepFM(
57
57
  session_id="deepfm_demo",
58
58
  )
59
59
 
60
+ # Configure optimizer/loss/scheduler via compile()
60
61
  model.compile(
61
62
  optimizer="adam",
62
63
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -49,12 +49,13 @@ train_df, valid_df = train_test_split(df, test_size=0.2, random_state=2024)
49
49
  model = DeepFM(
50
50
  dense_features=dense_features,
51
51
  sparse_features=sparse_features,
52
- mlp_params={"dims": [256, 128], "activation": "relu", "dropout": 0.2},
52
+ mlp_params={"hidden_dims": [256, 128], "activation": "relu", "dropout": 0.2},
53
53
  target="label",
54
54
  device="cpu",
55
55
  session_id="movielens_deepfm", # 管理实验日志与检查点
56
56
  )
57
57
 
58
+ # 优化器/损失/学习率调度器统一在 compile 中设置
58
59
  model.compile(
59
60
  optimizer="adam",
60
61
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -0,0 +1 @@
1
+ __version__ = "0.4.24"
@@ -20,6 +20,7 @@ import torch.nn.functional as F
20
20
  from nextrec.basic.activation import activation_layer
21
21
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
22
22
  from nextrec.utils.torch_utils import get_initializer
23
+ from nextrec.utils.types import ActivationName
23
24
 
24
25
 
25
26
  class PredictionLayer(nn.Module):
@@ -590,71 +591,48 @@ class MLP(nn.Module):
590
591
  def __init__(
591
592
  self,
592
593
  input_dim: int,
593
- output_layer: bool = True,
594
- dims: list[int] | None = None,
594
+ hidden_dims: list[int] | None = None,
595
+ output_dim: int | None = 1,
595
596
  dropout: float = 0.0,
596
- activation: Literal[
597
- "dice",
598
- "relu",
599
- "relu6",
600
- "elu",
601
- "selu",
602
- "leaky_relu",
603
- "prelu",
604
- "gelu",
605
- "sigmoid",
606
- "tanh",
607
- "softplus",
608
- "softsign",
609
- "hardswish",
610
- "mish",
611
- "silu",
612
- "swish",
613
- "hardsigmoid",
614
- "tanhshrink",
615
- "softshrink",
616
- "none",
617
- "linear",
618
- "identity",
619
- ] = "relu",
620
- use_norm: bool = True,
621
- norm_type: Literal["batch_norm", "layer_norm"] = "layer_norm",
597
+ activation: ActivationName = "relu",
598
+ norm_type: Literal["batch_norm", "layer_norm", "none"] = "none",
599
+ output_activation: ActivationName = "none",
622
600
  ):
623
601
  """
624
602
  Multi-Layer Perceptron (MLP) module.
625
603
 
626
604
  Args:
627
605
  input_dim: Dimension of the input features.
628
- output_layer: Whether to include the final output layer. If False, the MLP will output the last hidden layer, else it will output a single value.
629
- dims: List of hidden layer dimensions. If None, no hidden layers are added.
606
+ output_dim: Output dimension of the final layer. If None, no output layer is added.
607
+ hidden_dims: List of hidden layer dimensions. If None, no hidden layers are added.
630
608
  dropout: Dropout rate between layers.
631
609
  activation: Activation function to use between layers.
632
- use_norm: Whether to use normalization layers.
633
- norm_type: Type of normalization to use ("batch_norm" or "layer_norm").
610
+ norm_type: Type of normalization to use ("batch_norm", "layer_norm", or "none").
611
+ output_activation: Activation function applied after the output layer.
634
612
  """
635
613
  super().__init__()
636
- if dims is None:
637
- dims = []
614
+ hidden_dims = hidden_dims or []
638
615
  layers = []
639
616
  current_dim = input_dim
640
- for i_dim in dims:
617
+ for i_dim in hidden_dims:
641
618
  layers.append(nn.Linear(current_dim, i_dim))
642
- if use_norm:
643
- if norm_type == "batch_norm":
644
- # **IMPORTANT** be careful when using BatchNorm1d in distributed training, nextrec does not support sync batch norm now
645
- layers.append(nn.BatchNorm1d(i_dim))
646
- elif norm_type == "layer_norm":
647
- layers.append(nn.LayerNorm(i_dim))
648
- else:
649
- raise ValueError(f"Unsupported norm_type: {norm_type}")
619
+ if norm_type == "batch_norm":
620
+ # **IMPORTANT** be careful when using BatchNorm1d in distributed training, nextrec does not support sync batch norm now
621
+ layers.append(nn.BatchNorm1d(i_dim))
622
+ elif norm_type == "layer_norm":
623
+ layers.append(nn.LayerNorm(i_dim))
624
+ elif norm_type != "none":
625
+ raise ValueError(f"Unsupported norm_type: {norm_type}")
650
626
 
651
627
  layers.append(activation_layer(activation))
652
628
  layers.append(nn.Dropout(p=dropout))
653
629
  current_dim = i_dim
654
630
  # output layer
655
- if output_layer:
656
- layers.append(nn.Linear(current_dim, 1))
657
- self.output_dim = 1
631
+ if output_dim is not None:
632
+ layers.append(nn.Linear(current_dim, output_dim))
633
+ if output_activation != "none":
634
+ layers.append(activation_layer(output_activation))
635
+ self.output_dim = output_dim
658
636
  else:
659
637
  self.output_dim = current_dim
660
638
  self.mlp = nn.Sequential(*layers)
@@ -663,6 +641,47 @@ class MLP(nn.Module):
663
641
  return self.mlp(x)
664
642
 
665
643
 
644
+ class GateMLP(nn.Module):
645
+ """
646
+ Lightweight gate network: sigmoid MLP scaled by a constant factor.
647
+
648
+ Args:
649
+ input_dim: Dimension of the input features.
650
+ hidden_dim: Dimension of the hidden layer. If None, defaults to output_dim.
651
+ output_dim: Output dimension of the gate.
652
+ activation: Activation function to use in the hidden layer.
653
+ dropout: Dropout rate between layers.
654
+ use_bn: Whether to use batch normalization.
655
+ scale_factor: Scaling factor applied to the sigmoid output.
656
+ """
657
+
658
+ def __init__(
659
+ self,
660
+ input_dim: int,
661
+ hidden_dim: int | None,
662
+ output_dim: int,
663
+ activation: ActivationName = "relu",
664
+ dropout: float = 0.0,
665
+ use_bn: bool = False,
666
+ scale_factor: float = 2.0,
667
+ ) -> None:
668
+ super().__init__()
669
+ hidden_dim = output_dim if hidden_dim is None else hidden_dim
670
+ self.gate = MLP(
671
+ input_dim=input_dim,
672
+ hidden_dims=[hidden_dim],
673
+ output_dim=output_dim,
674
+ activation=activation,
675
+ dropout=dropout,
676
+ norm_type="batch_norm" if use_bn else "none",
677
+ output_activation="sigmoid",
678
+ )
679
+ self.scale_factor = scale_factor
680
+
681
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
682
+ return self.gate(inputs) * self.scale_factor
683
+
684
+
666
685
  class FM(nn.Module):
667
686
  def __init__(self, reduce_sum: bool = True):
668
687
  super().__init__()
@@ -1007,3 +1026,34 @@ class RMSNorm(torch.nn.Module):
1007
1026
  variance = torch.mean(x**2, dim=-1, keepdim=True)
1008
1027
  x_normalized = x * torch.rsqrt(variance + self.eps)
1009
1028
  return self.weight * x_normalized
1029
+
1030
+
1031
+ class DomainBatchNorm(nn.Module):
1032
+ """Domain-specific BatchNorm (applied per-domain with a shared interface)."""
1033
+
1034
+ def __init__(self, num_features: int, num_domains: int):
1035
+ super().__init__()
1036
+ if num_domains < 1:
1037
+ raise ValueError("num_domains must be >= 1")
1038
+ self.bns = nn.ModuleList(
1039
+ [nn.BatchNorm1d(num_features) for _ in range(num_domains)]
1040
+ )
1041
+
1042
+ def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
1043
+ if x.dim() != 2:
1044
+ raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
1045
+ output = x.clone()
1046
+ if domain_mask.dim() == 1:
1047
+ domain_ids = domain_mask.long()
1048
+ for idx, bn in enumerate(self.bns):
1049
+ mask = domain_ids == idx
1050
+ if mask.any():
1051
+ output[mask] = bn(x[mask])
1052
+ return output
1053
+ if domain_mask.dim() != 2:
1054
+ raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
1055
+ for idx, bn in enumerate(self.bns):
1056
+ mask = domain_mask[:, idx] > 0
1057
+ if mask.any():
1058
+ output[mask] = bn(x[mask])
1059
+ return output