nextrec 0.4.15__tar.gz → 0.4.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. {nextrec-0.4.15 → nextrec-0.4.17}/PKG-INFO +5 -5
  2. {nextrec-0.4.15 → nextrec-0.4.17}/README.md +4 -4
  3. {nextrec-0.4.15 → nextrec-0.4.17}/README_en.md +4 -4
  4. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/conf.py +1 -1
  5. nextrec-0.4.17/nextrec/__version__.py +1 -0
  6. nextrec-0.4.17/nextrec/basic/heads.py +101 -0
  7. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/metrics.py +2 -0
  8. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/model.py +18 -14
  9. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/multi_task/esmm.py +4 -3
  10. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/multi_task/mmoe.py +4 -3
  11. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/multi_task/ple.py +4 -3
  12. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/multi_task/poso.py +4 -3
  13. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/multi_task/share_bottom.py +4 -3
  14. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/afm.py +4 -3
  15. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/autoint.py +4 -3
  16. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/dcn.py +4 -3
  17. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/dcn_v2.py +4 -3
  18. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/deepfm.py +4 -3
  19. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/dien.py +2 -2
  20. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/din.py +2 -2
  21. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/eulernet.py +4 -3
  22. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/ffm.py +4 -3
  23. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/fibinet.py +2 -2
  24. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/fm.py +4 -3
  25. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/lr.py +4 -3
  26. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/masknet.py +4 -3
  27. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/pnn.py +4 -3
  28. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/widedeep.py +4 -3
  29. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/xdeepfm.py +4 -3
  30. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/config.py +12 -3
  31. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/console.py +83 -28
  32. {nextrec-0.4.15 → nextrec-0.4.17}/pyproject.toml +1 -1
  33. nextrec-0.4.15/nextrec/__version__.py +0 -1
  34. {nextrec-0.4.15 → nextrec-0.4.17}/.github/workflows/publish.yml +0 -0
  35. {nextrec-0.4.15 → nextrec-0.4.17}/.github/workflows/tests.yml +0 -0
  36. {nextrec-0.4.15 → nextrec-0.4.17}/.gitignore +0 -0
  37. {nextrec-0.4.15 → nextrec-0.4.17}/.readthedocs.yaml +0 -0
  38. {nextrec-0.4.15 → nextrec-0.4.17}/CODE_OF_CONDUCT.md +0 -0
  39. {nextrec-0.4.15 → nextrec-0.4.17}/CONTRIBUTING.md +0 -0
  40. {nextrec-0.4.15 → nextrec-0.4.17}/LICENSE +0 -0
  41. {nextrec-0.4.15 → nextrec-0.4.17}/MANIFEST.in +0 -0
  42. {nextrec-0.4.15 → nextrec-0.4.17}/assets/Feature Configuration.png +0 -0
  43. {nextrec-0.4.15 → nextrec-0.4.17}/assets/Model Parameters.png +0 -0
  44. {nextrec-0.4.15 → nextrec-0.4.17}/assets/Training Configuration.png +0 -0
  45. {nextrec-0.4.15 → nextrec-0.4.17}/assets/Training logs.png +0 -0
  46. {nextrec-0.4.15 → nextrec-0.4.17}/assets/logo.png +0 -0
  47. {nextrec-0.4.15 → nextrec-0.4.17}/assets/mmoe_tutorial.png +0 -0
  48. {nextrec-0.4.15 → nextrec-0.4.17}/assets/nextrec_diagram.png +0 -0
  49. {nextrec-0.4.15 → nextrec-0.4.17}/assets/test data.png +0 -0
  50. {nextrec-0.4.15 → nextrec-0.4.17}/dataset/ctcvr_task.csv +0 -0
  51. {nextrec-0.4.15 → nextrec-0.4.17}/dataset/ecommerce_task.csv +0 -0
  52. {nextrec-0.4.15 → nextrec-0.4.17}/dataset/match_task.csv +0 -0
  53. {nextrec-0.4.15 → nextrec-0.4.17}/dataset/movielens_100k.csv +0 -0
  54. {nextrec-0.4.15 → nextrec-0.4.17}/dataset/multitask_task.csv +0 -0
  55. {nextrec-0.4.15 → nextrec-0.4.17}/dataset/ranking_task.csv +0 -0
  56. {nextrec-0.4.15 → nextrec-0.4.17}/docs/en/Getting started guide.md +0 -0
  57. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/Makefile +0 -0
  58. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/index.md +0 -0
  59. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/make.bat +0 -0
  60. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/modules.rst +0 -0
  61. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/nextrec.basic.rst +0 -0
  62. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/nextrec.data.rst +0 -0
  63. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/nextrec.loss.rst +0 -0
  64. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/nextrec.rst +0 -0
  65. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/nextrec.utils.rst +0 -0
  66. {nextrec-0.4.15 → nextrec-0.4.17}/docs/rtd/requirements.txt +0 -0
  67. {nextrec-0.4.15 → nextrec-0.4.17}/docs/zh/快速上手.md +0 -0
  68. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/__init__.py +0 -0
  69. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/__init__.py +0 -0
  70. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/activation.py +0 -0
  71. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/callback.py +0 -0
  72. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/features.py +0 -0
  73. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/layers.py +0 -0
  74. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/loggers.py +0 -0
  75. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/basic/session.py +0 -0
  76. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/cli.py +0 -0
  77. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/data/__init__.py +0 -0
  78. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/data/batch_utils.py +0 -0
  79. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/data/data_processing.py +0 -0
  80. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/data/data_utils.py +0 -0
  81. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/data/dataloader.py +0 -0
  82. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/data/preprocessor.py +0 -0
  83. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/loss/__init__.py +0 -0
  84. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/loss/grad_norm.py +0 -0
  85. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/loss/listwise.py +0 -0
  86. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/loss/loss_utils.py +0 -0
  87. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/loss/pairwise.py +0 -0
  88. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/loss/pointwise.py +0 -0
  89. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/generative/__init__.py +0 -0
  90. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/generative/tiger.py +0 -0
  91. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/multi_task/__init__.py +0 -0
  92. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/ranking/__init__.py +0 -0
  93. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/__init__.py +0 -0
  94. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/autorec.py +0 -0
  95. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/bpr.py +0 -0
  96. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/cl4srec.py +0 -0
  97. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/lightgcn.py +0 -0
  98. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/mf.py +0 -0
  99. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/rqvae.py +0 -0
  100. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/representation/s3rec.py +0 -0
  101. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/retrieval/__init__.py +0 -0
  102. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/retrieval/dssm.py +0 -0
  103. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/retrieval/dssm_v2.py +0 -0
  104. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/retrieval/mind.py +0 -0
  105. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/retrieval/sdm.py +0 -0
  106. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/retrieval/youtube_dnn.py +0 -0
  107. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/sequential/hstu.py +0 -0
  108. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/models/sequential/sasrec.py +0 -0
  109. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/__init__.py +0 -0
  110. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/data.py +0 -0
  111. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/embedding.py +0 -0
  112. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/feature.py +0 -0
  113. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/model.py +0 -0
  114. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec/utils/torch_utils.py +0 -0
  115. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/NextRec-CLI.md +0 -0
  116. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
  117. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/feature_config.yaml +0 -0
  118. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  119. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  120. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  121. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  122. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  123. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  124. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  125. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  126. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  127. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  128. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  129. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  130. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  131. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  132. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  133. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  134. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/predict_config.yaml +0 -0
  135. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/predict_config_template.yaml +0 -0
  136. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/train_config.yaml +0 -0
  137. {nextrec-0.4.15 → nextrec-0.4.17}/nextrec_cli_preset/train_config_template.yaml +0 -0
  138. {nextrec-0.4.15 → nextrec-0.4.17}/pytest.ini +0 -0
  139. {nextrec-0.4.15 → nextrec-0.4.17}/requirements.txt +0 -0
  140. {nextrec-0.4.15 → nextrec-0.4.17}/scripts/format_code.py +0 -0
  141. {nextrec-0.4.15 → nextrec-0.4.17}/test/__init__.py +0 -0
  142. {nextrec-0.4.15 → nextrec-0.4.17}/test/conftest.py +0 -0
  143. {nextrec-0.4.15 → nextrec-0.4.17}/test/helpers.py +0 -0
  144. {nextrec-0.4.15 → nextrec-0.4.17}/test/run_tests.py +0 -0
  145. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_base_model_regularization.py +0 -0
  146. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_generative_models.py +0 -0
  147. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_layers.py +0 -0
  148. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_losses.py +0 -0
  149. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_match_models.py +0 -0
  150. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_multitask_models.py +0 -0
  151. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_preprocessor.py +0 -0
  152. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_ranking_models.py +0 -0
  153. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_utils_console.py +0 -0
  154. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_utils_data.py +0 -0
  155. {nextrec-0.4.15 → nextrec-0.4.17}/test/test_utils_embedding.py +0 -0
  156. {nextrec-0.4.15 → nextrec-0.4.17}/test_requirements.txt +0 -0
  157. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/distributed/example_distributed_training.py +0 -0
  158. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  159. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/example_multitask.py +0 -0
  160. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/example_ranking_din.py +0 -0
  161. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/movielen_match_dssm.py +0 -0
  162. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/movielen_ranking_deepfm.py +0 -0
  163. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  164. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  165. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  166. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  167. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  168. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0
  169. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/run_all_match_models.py +0 -0
  170. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/run_all_multitask_models.py +0 -0
  171. {nextrec-0.4.15 → nextrec-0.4.17}/tutorials/run_all_ranking_models.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.15
3
+ Version: 0.4.17
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -68,7 +68,7 @@ Description-Content-Type: text/markdown
68
68
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
69
69
 
70
70
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
71
- ![Version](https://img.shields.io/badge/Version-0.4.15-orange.svg)
71
+ ![Version](https://img.shields.io/badge/Version-0.4.17-orange.svg)
72
72
 
73
73
 
74
74
  中文文档 | [English Version](README_en.md)
@@ -102,7 +102,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
102
102
 
103
103
  ## NextRec近期进展
104
104
 
105
- - **21/12/2025** 在v0.4.15中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
105
+ - **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
106
106
  - **12/12/2025** 在v0.4.9中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
107
107
  - **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
108
108
  - **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
@@ -244,11 +244,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
244
244
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
245
245
  ```
246
246
 
247
- > 截止当前版本0.4.15,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
247
+ > 截止当前版本0.4.17,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
248
248
 
249
249
  ## 兼容平台
250
250
 
251
- 当前最新版本为0.4.15,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
251
+ 当前最新版本为0.4.17,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
252
252
 
253
253
  | 平台 | 配置 |
254
254
  |------|------|
@@ -9,7 +9,7 @@
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
 
11
11
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
12
- ![Version](https://img.shields.io/badge/Version-0.4.15-orange.svg)
12
+ ![Version](https://img.shields.io/badge/Version-0.4.17-orange.svg)
13
13
 
14
14
 
15
15
  中文文档 | [English Version](README_en.md)
@@ -43,7 +43,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
43
43
 
44
44
  ## NextRec近期进展
45
45
 
46
- - **21/12/2025** 在v0.4.15中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
46
+ - **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
47
47
  - **12/12/2025** 在v0.4.9中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
48
48
  - **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
49
49
  - **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
@@ -185,11 +185,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
185
185
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
186
186
  ```
187
187
 
188
- > 截止当前版本0.4.15,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
188
+ > 截止当前版本0.4.17,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
189
189
 
190
190
  ## 兼容平台
191
191
 
192
- 当前最新版本为0.4.15,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
192
+ 当前最新版本为0.4.17,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
193
193
 
194
194
  | 平台 | 配置 |
195
195
  |------|------|
@@ -9,7 +9,7 @@
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
 
11
11
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
12
- ![Version](https://img.shields.io/badge/Version-0.4.15-orange.svg)
12
+ ![Version](https://img.shields.io/badge/Version-0.4.17-orange.svg)
13
13
 
14
14
  English | [中文文档](README.md)
15
15
 
@@ -44,7 +44,7 @@ NextRec is a modern recommendation framework built on PyTorch, delivering a unif
44
44
 
45
45
  ## NextRec Progress
46
46
 
47
- - **21/12/2025** Added support for [GradNorm](/nextrec/loss/grad_norm.py) in v0.4.15, configurable via `loss_weight='grad_norm'` in the compile method
47
+ - **21/12/2025** Added support for [GradNorm](/nextrec/loss/grad_norm.py) in v0.4.16, configurable via `loss_weight='grad_norm'` in the compile method
48
48
  - **12/12/2025** Added [RQ-VAE](/nextrec/models/representation/rqvae.py), a common module for generative retrieval in v0.4.9. Paired [dataset](/dataset/ecommerce_task.csv) and [notebook code](tutorials/notebooks/en/Build%20semantic%20ID%20with%20RQ-VAE.ipynb) are available.
49
49
  - **07/12/2025** Released the NextRec CLI tool to run training/inference from configs. See the [guide](/nextrec_cli_preset/NextRec-CLI.md) and [reference code](/nextrec_cli_preset).
50
50
  - **03/12/2025** NextRec reached 100 ⭐—thanks for the support!
@@ -188,11 +188,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
188
188
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
189
189
  ```
190
190
 
191
- > As of version 0.4.15, NextRec CLI supports single-machine training; distributed training features are currently under development.
191
+ > As of version 0.4.17, NextRec CLI supports single-machine training; distributed training features are currently under development.
192
192
 
193
193
  ## Platform Compatibility
194
194
 
195
- The current version is 0.4.15. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
195
+ The current version is 0.4.17. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
196
196
 
197
197
  | Platform | Configuration |
198
198
  |----------|---------------|
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
11
11
  project = "NextRec"
12
12
  copyright = "2025, Yang Zhou"
13
13
  author = "Yang Zhou"
14
- release = "0.4.15"
14
+ release = "0.4.17"
15
15
 
16
16
  extensions = [
17
17
  "myst_parser",
@@ -0,0 +1 @@
1
+ __version__ = "0.4.17"
@@ -0,0 +1,101 @@
1
+ """
2
+ Task head implementations for NextRec models.
3
+
4
+ Date: create on 23/12/2025
5
+ Author: Yang Zhou, zyaztec@gmail.com
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Literal
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from nextrec.basic.layers import PredictionLayer
17
+
18
+
19
class TaskHead(nn.Module):
    """Unified task head for ranking/regression/multi-task outputs.

    A thin wrapper around ``PredictionLayer`` that gives models a "Head"
    abstraction to depend on, without changing their existing forward
    signatures.
    """

    def __init__(
        self,
        task_type: str | list[str] = "binary",
        task_dims: int | list[int] | None = None,
        use_bias: bool = True,
        return_logits: bool = False,
    ) -> None:
        super().__init__()
        layer = PredictionLayer(
            task_type=task_type,
            task_dims=task_dims,
            use_bias=use_bias,
            return_logits=return_logits,
        )
        self.prediction = layer
        # Mirror the wrapped layer's commonly used attributes so callers can
        # treat a TaskHead and a PredictionLayer interchangeably.
        self.task_types = layer.task_types
        self.task_dims = layer.task_dims
        self.task_slices = layer.task_slices
        self.total_dim = layer.total_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure delegation: all task-splitting/activation logic lives in
        # the wrapped PredictionLayer.
        return self.prediction(x)
49
+
50
+
51
class RetrievalHead(nn.Module):
    """Retrieval head for two-tower models.

    In pointwise mode (and at inference time) it scores user/item pairs with
    a similarity function; during pairwise/listwise training it returns the
    raw tower embeddings so the loss can do in-batch negative sampling.

    Args:
        similarity_metric: One of ``"dot"``, ``"cosine"`` or ``"euclidean"``
            (negative squared L2 distance). Ignored when a custom
            ``similarity_fn`` is passed to :meth:`forward`.
        temperature: Similarity scores are divided by this value before the
            optional sigmoid.
        training_mode: One of ``"pointwise"``, ``"pairwise"``, ``"listwise"``.
        apply_sigmoid: If True, pointwise scores are squashed with a sigmoid.

    Raises:
        ValueError: If ``similarity_metric`` or ``training_mode`` is not one
            of the supported values.
    """

    _METRICS = ("dot", "cosine", "euclidean")
    _MODES = ("pointwise", "pairwise", "listwise")

    def __init__(
        self,
        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
        temperature: float = 1.0,
        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
        apply_sigmoid: bool = True,
    ) -> None:
        super().__init__()
        # Fail fast: previously an invalid metric only surfaced at forward
        # time (and never when a custom similarity_fn was supplied), and an
        # invalid training_mode was silently treated as "not pointwise".
        if similarity_metric not in self._METRICS:
            raise ValueError(f"Unknown similarity metric: {similarity_metric}")
        if training_mode not in self._MODES:
            raise ValueError(f"Unknown training mode: {training_mode}")
        self.similarity_metric = similarity_metric
        self.temperature = temperature
        self.training_mode = training_mode
        self.apply_sigmoid = apply_sigmoid

    def forward(
        self,
        user_emb: torch.Tensor,
        item_emb: torch.Tensor,
        similarity_fn=None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # Pairwise/listwise training: hand back raw embeddings for
        # in-batch negative sampling in the loss.
        if self.training and self.training_mode in {"pairwise", "listwise"}:
            return user_emb, item_emb

        if similarity_fn is not None:
            # Caller-provided scorer takes precedence over the built-in metric.
            similarity = similarity_fn(user_emb, item_emb)
        else:
            # Broadcast a single user vector against multiple candidate items:
            # [B, D] x [B, K, D] -> [B, 1, D] x [B, K, D].
            if user_emb.dim() == 2 and item_emb.dim() == 3:
                user_emb = user_emb.unsqueeze(1)

            if self.similarity_metric == "dot":
                similarity = torch.sum(user_emb * item_emb, dim=-1)
            elif self.similarity_metric == "cosine":
                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
            elif self.similarity_metric == "euclidean":
                # Negative squared distance so "larger is more similar".
                similarity = -torch.sum((user_emb - item_emb) ** 2, dim=-1)
            else:  # pragma: no cover - unreachable after __init__ validation
                raise ValueError(
                    f"Unknown similarity metric: {self.similarity_metric}"
                )

        similarity = similarity / self.temperature
        if self.training_mode == "pointwise" and self.apply_sigmoid:
            return torch.sigmoid(similarity)
        return similarity
@@ -77,6 +77,8 @@ def check_user_id(*metric_sources: Any) -> bool:
77
77
 
78
78
  def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
79
79
  """Compute Kolmogorov-Smirnov statistic."""
80
+ y_true = np.asarray(y_true).reshape(-1)
81
+ y_pred = np.asarray(y_pred).reshape(-1)
80
82
  sorted_indices = np.argsort(y_pred)[::-1]
81
83
  y_true_sorted = y_true[sorted_indices]
82
84
 
@@ -38,6 +38,7 @@ from nextrec.basic.features import (
38
38
  SequenceFeature,
39
39
  SparseFeature,
40
40
  )
41
+ from nextrec.basic.heads import RetrievalHead
41
42
  from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
42
43
  from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
43
44
  from nextrec.basic.session import create_session, resolve_save_path
@@ -481,7 +482,7 @@ class BaseModel(FeatureSet, nn.Module):
481
482
  "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
482
483
  )
483
484
  loss_weights = loss_weights[0]
484
- self.loss_weights = [float(loss_weights)] # type: ignore
485
+ self.loss_weights = [float(loss_weights)] # type: ignore
485
486
  else:
486
487
  if isinstance(loss_weights, (int, float)):
487
488
  weights = [float(loss_weights)] * self.nums_task
@@ -591,8 +592,8 @@ class BaseModel(FeatureSet, nn.Module):
591
592
 
592
593
  def fit(
593
594
  self,
594
- train_data = None,
595
- valid_data = None,
595
+ train_data=None,
596
+ valid_data=None,
596
597
  metrics: (
597
598
  list[str] | dict[str, list[str]] | None
598
599
  ) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
@@ -1583,8 +1584,11 @@ class BaseModel(FeatureSet, nn.Module):
1583
1584
  else:
1584
1585
  data_loader = data
1585
1586
 
1586
- if hasattr(data_loader, 'num_workers') and data_loader.num_workers > 0:
1587
- if hasattr(data_loader.dataset, '__class__') and 'Streaming' in data_loader.dataset.__class__.__name__:
1587
+ if hasattr(data_loader, "num_workers") and data_loader.num_workers > 0:
1588
+ if (
1589
+ hasattr(data_loader.dataset, "__class__")
1590
+ and "Streaming" in data_loader.dataset.__class__.__name__
1591
+ ):
1588
1592
  logging.warning(
1589
1593
  f"[Predict Streaming Warning] Detected DataLoader with num_workers={data_loader.num_workers} "
1590
1594
  "and streaming dataset. This may cause data duplication! "
@@ -2112,6 +2116,12 @@ class BaseMatchModel(BaseModel):
2112
2116
  )
2113
2117
  self.user_feature_names = {feature.name for feature in self.user_features_all}
2114
2118
  self.item_feature_names = {feature.name for feature in self.item_features_all}
2119
+ self.head = RetrievalHead(
2120
+ similarity_metric=self.similarity_metric,
2121
+ temperature=self.temperature,
2122
+ training_mode=self.training_mode,
2123
+ apply_sigmoid=True,
2124
+ )
2115
2125
 
2116
2126
  def compile(
2117
2127
  self,
@@ -2241,15 +2251,9 @@ class BaseMatchModel(BaseModel):
2241
2251
  user_emb = self.user_tower(user_input) # [B, D]
2242
2252
  item_emb = self.item_tower(item_input) # [B, D]
2243
2253
 
2244
- if self.training and self.training_mode in ["pairwise", "listwise"]:
2245
- return user_emb, item_emb
2246
-
2247
- similarity = self.compute_similarity(user_emb, item_emb) # [B]
2248
-
2249
- if self.training_mode == "pointwise":
2250
- return torch.sigmoid(similarity)
2251
- else:
2252
- return similarity
2254
+ return self.head(
2255
+ user_emb, item_emb, similarity_fn=self.compute_similarity
2256
+ )
2253
2257
 
2254
2258
  def compute_loss(self, y_pred, y_true):
2255
2259
  if self.training_mode == "pointwise":
@@ -45,7 +45,8 @@ import torch
45
45
  import torch.nn as nn
46
46
 
47
47
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
48
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
48
+ from nextrec.basic.layers import MLP, EmbeddingLayer
49
+ from nextrec.basic.heads import TaskHead
49
50
  from nextrec.basic.model import BaseModel
50
51
 
51
52
 
@@ -139,7 +140,7 @@ class ESMM(BaseModel):
139
140
  # CVR tower
140
141
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
141
142
  self.grad_norm_shared_modules = ["embedding"]
142
- self.prediction_layer = PredictionLayer(
143
+ self.prediction_layer = TaskHead(
143
144
  task_type=self.default_task, task_dims=[1, 1]
144
145
  )
145
146
  # Register regularization weights
@@ -167,4 +168,4 @@ class ESMM(BaseModel):
167
168
 
168
169
  # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
169
170
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
170
- return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
171
+ return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
@@ -46,7 +46,8 @@ import torch
46
46
  import torch.nn as nn
47
47
 
48
48
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
49
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
49
+ from nextrec.basic.layers import MLP, EmbeddingLayer
50
+ from nextrec.basic.heads import TaskHead
50
51
  from nextrec.basic.model import BaseModel
51
52
 
52
53
 
@@ -172,7 +173,7 @@ class MMOE(BaseModel):
172
173
  for tower_params in tower_params_list:
173
174
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
174
175
  self.towers.append(tower)
175
- self.prediction_layer = PredictionLayer(
176
+ self.prediction_layer = TaskHead(
176
177
  task_type=self.default_task, task_dims=[1] * self.num_tasks
177
178
  )
178
179
  # Register regularization weights
@@ -219,4 +220,4 @@ class MMOE(BaseModel):
219
220
 
220
221
  # Stack outputs: [B, num_tasks]
221
222
  y = torch.cat(task_outputs, dim=1)
222
- return self.prediction_layer(y)
223
+ return self.prediction_layer(y)
@@ -49,7 +49,8 @@ import torch
49
49
  import torch.nn as nn
50
50
 
51
51
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
52
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
52
+ from nextrec.basic.layers import MLP, EmbeddingLayer
53
+ from nextrec.basic.heads import TaskHead
53
54
  from nextrec.basic.model import BaseModel
54
55
  from nextrec.utils.model import get_mlp_output_dim
55
56
 
@@ -302,7 +303,7 @@ class PLE(BaseModel):
302
303
  for tower_params in tower_params_list:
303
304
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
304
305
  self.towers.append(tower)
305
- self.prediction_layer = PredictionLayer(
306
+ self.prediction_layer = TaskHead(
306
307
  task_type=self.default_task, task_dims=[1] * self.num_tasks
307
308
  )
308
309
  # Register regularization weights
@@ -336,4 +337,4 @@ class PLE(BaseModel):
336
337
 
337
338
  # [B, num_tasks]
338
339
  y = torch.cat(task_outputs, dim=1)
339
- return self.prediction_layer(y)
340
+ return self.prediction_layer(y)
@@ -44,7 +44,8 @@ import torch.nn.functional as F
44
44
 
45
45
  from nextrec.basic.activation import activation_layer
46
46
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
47
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
47
+ from nextrec.basic.layers import MLP, EmbeddingLayer
48
+ from nextrec.basic.heads import TaskHead
48
49
  from nextrec.basic.model import BaseModel
49
50
  from nextrec.utils.model import select_features
50
51
 
@@ -487,7 +488,7 @@ class POSO(BaseModel):
487
488
  self.grad_norm_shared_modules = ["embedding"]
488
489
  else:
489
490
  self.grad_norm_shared_modules = ["embedding", "mmoe"]
490
- self.prediction_layer = PredictionLayer(
491
+ self.prediction_layer = TaskHead(
491
492
  task_type=self.default_task,
492
493
  task_dims=[1] * self.num_tasks,
493
494
  )
@@ -524,4 +525,4 @@ class POSO(BaseModel):
524
525
  task_outputs.append(logit)
525
526
 
526
527
  y = torch.cat(task_outputs, dim=1)
527
- return self.prediction_layer(y)
528
+ return self.prediction_layer(y)
@@ -43,7 +43,8 @@ import torch
43
43
  import torch.nn as nn
44
44
 
45
45
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
46
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
46
+ from nextrec.basic.layers import MLP, EmbeddingLayer
47
+ from nextrec.basic.heads import TaskHead
47
48
  from nextrec.basic.model import BaseModel
48
49
 
49
50
 
@@ -142,7 +143,7 @@ class ShareBottom(BaseModel):
142
143
  for tower_params in tower_params_list:
143
144
  tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
144
145
  self.towers.append(tower)
145
- self.prediction_layer = PredictionLayer(
146
+ self.prediction_layer = TaskHead(
146
147
  task_type=self.default_task, task_dims=[1] * self.num_tasks
147
148
  )
148
149
  # Register regularization weights
@@ -171,4 +172,4 @@ class ShareBottom(BaseModel):
171
172
 
172
173
  # Stack outputs: [B, num_tasks]
173
174
  y = torch.cat(task_outputs, dim=1)
174
- return self.prediction_layer(y)
175
+ return self.prediction_layer(y)
@@ -40,7 +40,8 @@ import torch
40
40
  import torch.nn as nn
41
41
 
42
42
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
43
- from nextrec.basic.layers import EmbeddingLayer, InputMask, PredictionLayer
43
+ from nextrec.basic.layers import EmbeddingLayer, InputMask
44
+ from nextrec.basic.heads import TaskHead
44
45
  from nextrec.basic.model import BaseModel
45
46
 
46
47
 
@@ -141,7 +142,7 @@ class AFM(BaseModel):
141
142
  self.attention_p = nn.Linear(attention_dim, 1, bias=False)
142
143
  self.attention_dropout = nn.Dropout(attention_dropout)
143
144
  self.output_projection = nn.Linear(self.embedding_dim, 1, bias=False)
144
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
145
+ self.prediction_layer = TaskHead(task_type=self.default_task)
145
146
  self.input_mask = InputMask()
146
147
 
147
148
  # Register regularization weights
@@ -243,4 +244,4 @@ class AFM(BaseModel):
243
244
  y_afm = self.output_projection(weighted_sum)
244
245
 
245
246
  y = y_linear + y_afm
246
- return self.prediction_layer(y)
247
+ return self.prediction_layer(y)
@@ -58,7 +58,8 @@ import torch
58
58
  import torch.nn as nn
59
59
 
60
60
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
61
- from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention, PredictionLayer
61
+ from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention
62
+ from nextrec.basic.heads import TaskHead
62
63
  from nextrec.basic.model import BaseModel
63
64
 
64
65
 
@@ -162,7 +163,7 @@ class AutoInt(BaseModel):
162
163
 
163
164
  # Final prediction layer
164
165
  self.fc = nn.Linear(num_fields * att_embedding_dim, 1)
165
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
166
+ self.prediction_layer = TaskHead(task_type=self.default_task)
166
167
 
167
168
  # Register regularization weights
168
169
  self.register_regularization_weights(
@@ -206,4 +207,4 @@ class AutoInt(BaseModel):
206
207
  start_dim=1
207
208
  ) # [B, num_fields * att_embedding_dim]
208
209
  y = self.fc(attention_output_flat) # [B, 1]
209
- return self.prediction_layer(y)
210
+ return self.prediction_layer(y)
@@ -54,7 +54,8 @@ import torch
54
54
  import torch.nn as nn
55
55
 
56
56
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
57
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
57
+ from nextrec.basic.layers import MLP, EmbeddingLayer
58
+ from nextrec.basic.heads import TaskHead
58
59
  from nextrec.basic.model import BaseModel
59
60
 
60
61
 
@@ -163,7 +164,7 @@ class DCN(BaseModel):
163
164
  # Final layer only uses cross network output
164
165
  self.final_layer = nn.Linear(input_dim, 1)
165
166
 
166
- self.prediction_layer = PredictionLayer(task_type=self.task)
167
+ self.prediction_layer = TaskHead(task_type=self.task)
167
168
 
168
169
  # Register regularization weights
169
170
  self.register_regularization_weights(
@@ -197,4 +198,4 @@ class DCN(BaseModel):
197
198
 
198
199
  # Final prediction
199
200
  y = self.final_layer(combined)
200
- return self.prediction_layer(y)
201
+ return self.prediction_layer(y)
@@ -47,7 +47,8 @@ import torch
47
47
  import torch.nn as nn
48
48
 
49
49
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
50
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
50
+ from nextrec.basic.layers import MLP, EmbeddingLayer
51
+ from nextrec.basic.heads import TaskHead
51
52
  from nextrec.basic.model import BaseModel
52
53
 
53
54
 
@@ -272,7 +273,7 @@ class DCNv2(BaseModel):
272
273
  final_input_dim = input_dim
273
274
 
274
275
  self.final_layer = nn.Linear(final_input_dim, 1)
275
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
276
+ self.prediction_layer = TaskHead(task_type=self.default_task)
276
277
 
277
278
  self.register_regularization_weights(
278
279
  embedding_attr="embedding",
@@ -301,4 +302,4 @@ class DCNv2(BaseModel):
301
302
  combined = cross_out
302
303
 
303
304
  logit = self.final_layer(combined)
304
- return self.prediction_layer(logit)
305
+ return self.prediction_layer(logit)
@@ -45,7 +45,8 @@ embedding,无需手工构造交叉特征即可端到端训练,常用于 CTR/
45
45
  import torch.nn as nn
46
46
 
47
47
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
48
- from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer, PredictionLayer
48
+ from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer
49
+ from nextrec.basic.heads import TaskHead
49
50
  from nextrec.basic.model import BaseModel
50
51
 
51
52
 
@@ -111,7 +112,7 @@ class DeepFM(BaseModel):
111
112
  self.linear = LR(fm_emb_dim_total)
112
113
  self.fm = FM(reduce_sum=True)
113
114
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
114
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
115
+ self.prediction_layer = TaskHead(task_type=self.default_task)
115
116
 
116
117
  # Register regularization weights
117
118
  self.register_regularization_weights(
@@ -133,4 +134,4 @@ class DeepFM(BaseModel):
133
134
  y_deep = self.mlp(input_deep) # [B, 1]
134
135
 
135
136
  y = y_linear + y_fm + y_deep
136
- return self.prediction_layer(y)
137
+ return self.prediction_layer(y)
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
55
55
  MLP,
56
56
  AttentionPoolingLayer,
57
57
  EmbeddingLayer,
58
- PredictionLayer,
59
58
  )
59
+ from nextrec.basic.heads import TaskHead
60
60
  from nextrec.basic.model import BaseModel
61
61
 
62
62
 
@@ -346,7 +346,7 @@ class DIEN(BaseModel):
346
346
  )
347
347
 
348
348
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
349
- self.prediction_layer = PredictionLayer(task_type=self.task)
349
+ self.prediction_layer = TaskHead(task_type=self.task)
350
350
 
351
351
  self.register_regularization_weights(
352
352
  embedding_attr="embedding",
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
55
55
  MLP,
56
56
  AttentionPoolingLayer,
57
57
  EmbeddingLayer,
58
- PredictionLayer,
59
58
  )
59
+ from nextrec.basic.heads import TaskHead
60
60
  from nextrec.basic.model import BaseModel
61
61
 
62
62
 
@@ -173,7 +173,7 @@ class DIN(BaseModel):
173
173
 
174
174
  # MLP for final prediction
175
175
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
176
- self.prediction_layer = PredictionLayer(task_type=self.task)
176
+ self.prediction_layer = TaskHead(task_type=self.task)
177
177
 
178
178
  # Register regularization weights
179
179
  self.register_regularization_weights(
@@ -38,7 +38,8 @@ import torch.nn as nn
38
38
  import torch.nn.functional as F
39
39
 
40
40
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
41
- from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
41
+ from nextrec.basic.layers import LR, EmbeddingLayer
42
+ from nextrec.basic.heads import TaskHead
42
43
  from nextrec.basic.model import BaseModel
43
44
 
44
45
 
@@ -295,7 +296,7 @@ class EulerNet(BaseModel):
295
296
  else:
296
297
  self.linear = None
297
298
 
298
- self.prediction_layer = PredictionLayer(task_type=self.task)
299
+ self.prediction_layer = TaskHead(task_type=self.task)
299
300
 
300
301
  modules = ["mapping", "layers", "w", "w_im"]
301
302
  if self.use_linear:
@@ -331,4 +332,4 @@ class EulerNet(BaseModel):
331
332
  r, p = layer(r, p)
332
333
  r_flat = r.reshape(r.size(0), self.num_orders * self.embedding_dim)
333
334
  p_flat = p.reshape(p.size(0), self.num_orders * self.embedding_dim)
334
- return self.w(r_flat) + self.w_im(p_flat)
335
+ return self.w(r_flat) + self.w_im(p_flat)
@@ -43,7 +43,8 @@ import torch
43
43
  import torch.nn as nn
44
44
 
45
45
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
46
- from nextrec.basic.layers import AveragePooling, InputMask, PredictionLayer, SumPooling
46
+ from nextrec.basic.layers import AveragePooling, InputMask, SumPooling
47
+ from nextrec.basic.heads import TaskHead
47
48
  from nextrec.basic.model import BaseModel
48
49
  from nextrec.utils.torch_utils import get_initializer
49
50
 
@@ -140,7 +141,7 @@ class FFM(BaseModel):
140
141
  nn.Linear(dense_input_dim, 1, bias=True) if dense_input_dim > 0 else None
141
142
  )
142
143
 
143
- self.prediction_layer = PredictionLayer(task_type=self.task)
144
+ self.prediction_layer = TaskHead(task_type=self.task)
144
145
  self.input_mask = InputMask()
145
146
  self.mean_pool = AveragePooling()
146
147
  self.sum_pool = SumPooling()
@@ -272,4 +273,4 @@ class FFM(BaseModel):
272
273
  )
273
274
 
274
275
  y = y_linear + y_interaction
275
- return self.prediction_layer(y)
276
+ return self.prediction_layer(y)