nextrec 0.4.16__tar.gz → 0.4.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. {nextrec-0.4.16 → nextrec-0.4.17}/PKG-INFO +4 -4
  2. {nextrec-0.4.16 → nextrec-0.4.17}/README.md +3 -3
  3. {nextrec-0.4.16 → nextrec-0.4.17}/README_en.md +3 -3
  4. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/conf.py +1 -1
  5. nextrec-0.4.17/nextrec/__version__.py +1 -0
  6. nextrec-0.4.17/nextrec/basic/heads.py +101 -0
  7. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/model.py +10 -9
  8. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/esmm.py +4 -3
  9. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/mmoe.py +4 -3
  10. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/ple.py +4 -3
  11. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/poso.py +4 -3
  12. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/share_bottom.py +4 -3
  13. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/afm.py +4 -3
  14. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/autoint.py +4 -3
  15. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dcn.py +4 -3
  16. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dcn_v2.py +4 -3
  17. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/deepfm.py +4 -3
  18. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dien.py +2 -2
  19. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/din.py +2 -2
  20. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/eulernet.py +4 -3
  21. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/ffm.py +4 -3
  22. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/fibinet.py +2 -2
  23. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/fm.py +4 -3
  24. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/lr.py +4 -3
  25. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/masknet.py +4 -3
  26. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/pnn.py +4 -3
  27. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/widedeep.py +4 -3
  28. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/xdeepfm.py +4 -3
  29. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/console.py +19 -5
  30. {nextrec-0.4.16 → nextrec-0.4.17}/pyproject.toml +1 -1
  31. nextrec-0.4.16/nextrec/__version__.py +0 -1
  32. {nextrec-0.4.16 → nextrec-0.4.17}/.github/workflows/publish.yml +0 -0
  33. {nextrec-0.4.16 → nextrec-0.4.17}/.github/workflows/tests.yml +0 -0
  34. {nextrec-0.4.16 → nextrec-0.4.17}/.gitignore +0 -0
  35. {nextrec-0.4.16 → nextrec-0.4.17}/.readthedocs.yaml +0 -0
  36. {nextrec-0.4.16 → nextrec-0.4.17}/CODE_OF_CONDUCT.md +0 -0
  37. {nextrec-0.4.16 → nextrec-0.4.17}/CONTRIBUTING.md +0 -0
  38. {nextrec-0.4.16 → nextrec-0.4.17}/LICENSE +0 -0
  39. {nextrec-0.4.16 → nextrec-0.4.17}/MANIFEST.in +0 -0
  40. {nextrec-0.4.16 → nextrec-0.4.17}/assets/Feature Configuration.png +0 -0
  41. {nextrec-0.4.16 → nextrec-0.4.17}/assets/Model Parameters.png +0 -0
  42. {nextrec-0.4.16 → nextrec-0.4.17}/assets/Training Configuration.png +0 -0
  43. {nextrec-0.4.16 → nextrec-0.4.17}/assets/Training logs.png +0 -0
  44. {nextrec-0.4.16 → nextrec-0.4.17}/assets/logo.png +0 -0
  45. {nextrec-0.4.16 → nextrec-0.4.17}/assets/mmoe_tutorial.png +0 -0
  46. {nextrec-0.4.16 → nextrec-0.4.17}/assets/nextrec_diagram.png +0 -0
  47. {nextrec-0.4.16 → nextrec-0.4.17}/assets/test data.png +0 -0
  48. {nextrec-0.4.16 → nextrec-0.4.17}/dataset/ctcvr_task.csv +0 -0
  49. {nextrec-0.4.16 → nextrec-0.4.17}/dataset/ecommerce_task.csv +0 -0
  50. {nextrec-0.4.16 → nextrec-0.4.17}/dataset/match_task.csv +0 -0
  51. {nextrec-0.4.16 → nextrec-0.4.17}/dataset/movielens_100k.csv +0 -0
  52. {nextrec-0.4.16 → nextrec-0.4.17}/dataset/multitask_task.csv +0 -0
  53. {nextrec-0.4.16 → nextrec-0.4.17}/dataset/ranking_task.csv +0 -0
  54. {nextrec-0.4.16 → nextrec-0.4.17}/docs/en/Getting started guide.md +0 -0
  55. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/Makefile +0 -0
  56. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/index.md +0 -0
  57. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/make.bat +0 -0
  58. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/modules.rst +0 -0
  59. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.basic.rst +0 -0
  60. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.data.rst +0 -0
  61. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.loss.rst +0 -0
  62. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.rst +0 -0
  63. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.utils.rst +0 -0
  64. {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/requirements.txt +0 -0
  65. {nextrec-0.4.16 → nextrec-0.4.17}/docs/zh/快速上手.md +0 -0
  66. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/__init__.py +0 -0
  67. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/__init__.py +0 -0
  68. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/activation.py +0 -0
  69. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/callback.py +0 -0
  70. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/features.py +0 -0
  71. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/layers.py +0 -0
  72. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/loggers.py +0 -0
  73. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/metrics.py +0 -0
  74. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/session.py +0 -0
  75. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/cli.py +0 -0
  76. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/__init__.py +0 -0
  77. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/batch_utils.py +0 -0
  78. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/data_processing.py +0 -0
  79. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/data_utils.py +0 -0
  80. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/dataloader.py +0 -0
  81. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/preprocessor.py +0 -0
  82. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/__init__.py +0 -0
  83. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/grad_norm.py +0 -0
  84. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/listwise.py +0 -0
  85. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/loss_utils.py +0 -0
  86. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/pairwise.py +0 -0
  87. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/pointwise.py +0 -0
  88. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/generative/__init__.py +0 -0
  89. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/generative/tiger.py +0 -0
  90. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/__init__.py +0 -0
  91. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/__init__.py +0 -0
  92. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/__init__.py +0 -0
  93. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/autorec.py +0 -0
  94. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/bpr.py +0 -0
  95. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/cl4srec.py +0 -0
  96. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/lightgcn.py +0 -0
  97. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/mf.py +0 -0
  98. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/rqvae.py +0 -0
  99. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/s3rec.py +0 -0
  100. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/__init__.py +0 -0
  101. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/dssm.py +0 -0
  102. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/dssm_v2.py +0 -0
  103. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/mind.py +0 -0
  104. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/sdm.py +0 -0
  105. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/youtube_dnn.py +0 -0
  106. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/sequential/hstu.py +0 -0
  107. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/sequential/sasrec.py +0 -0
  108. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/__init__.py +0 -0
  109. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/config.py +0 -0
  110. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/data.py +0 -0
  111. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/embedding.py +0 -0
  112. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/feature.py +0 -0
  113. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/model.py +0 -0
  114. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/torch_utils.py +0 -0
  115. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/NextRec-CLI.md +0 -0
  116. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
  117. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/feature_config.yaml +0 -0
  118. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  119. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  120. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  121. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  122. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  123. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  124. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  125. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  126. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  127. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  128. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  129. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  130. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  131. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  132. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  133. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  134. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/predict_config.yaml +0 -0
  135. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/predict_config_template.yaml +0 -0
  136. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/train_config.yaml +0 -0
  137. {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/train_config_template.yaml +0 -0
  138. {nextrec-0.4.16 → nextrec-0.4.17}/pytest.ini +0 -0
  139. {nextrec-0.4.16 → nextrec-0.4.17}/requirements.txt +0 -0
  140. {nextrec-0.4.16 → nextrec-0.4.17}/scripts/format_code.py +0 -0
  141. {nextrec-0.4.16 → nextrec-0.4.17}/test/__init__.py +0 -0
  142. {nextrec-0.4.16 → nextrec-0.4.17}/test/conftest.py +0 -0
  143. {nextrec-0.4.16 → nextrec-0.4.17}/test/helpers.py +0 -0
  144. {nextrec-0.4.16 → nextrec-0.4.17}/test/run_tests.py +0 -0
  145. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_base_model_regularization.py +0 -0
  146. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_generative_models.py +0 -0
  147. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_layers.py +0 -0
  148. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_losses.py +0 -0
  149. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_match_models.py +0 -0
  150. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_multitask_models.py +0 -0
  151. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_preprocessor.py +0 -0
  152. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_ranking_models.py +0 -0
  153. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_utils_console.py +0 -0
  154. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_utils_data.py +0 -0
  155. {nextrec-0.4.16 → nextrec-0.4.17}/test/test_utils_embedding.py +0 -0
  156. {nextrec-0.4.16 → nextrec-0.4.17}/test_requirements.txt +0 -0
  157. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/distributed/example_distributed_training.py +0 -0
  158. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  159. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/example_multitask.py +0 -0
  160. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/example_ranking_din.py +0 -0
  161. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/movielen_match_dssm.py +0 -0
  162. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/movielen_ranking_deepfm.py +0 -0
  163. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  164. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  165. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  166. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  167. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  168. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0
  169. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/run_all_match_models.py +0 -0
  170. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/run_all_multitask_models.py +0 -0
  171. {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/run_all_ranking_models.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.16
3
+ Version: 0.4.17
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -68,7 +68,7 @@ Description-Content-Type: text/markdown
68
68
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
69
69
 
70
70
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
71
- ![Version](https://img.shields.io/badge/Version-0.4.16-orange.svg)
71
+ ![Version](https://img.shields.io/badge/Version-0.4.17-orange.svg)
72
72
 
73
73
 
74
74
  中文文档 | [English Version](README_en.md)
@@ -244,11 +244,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
244
244
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
245
245
  ```
246
246
 
247
- > 截止当前版本0.4.16,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
247
+ > 截止当前版本0.4.17,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
248
248
 
249
249
  ## 兼容平台
250
250
 
251
- 当前最新版本为0.4.16,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
251
+ 当前最新版本为0.4.17,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
252
252
 
253
253
  | 平台 | 配置 |
254
254
  |------|------|
@@ -9,7 +9,7 @@
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
 
11
11
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
12
- ![Version](https://img.shields.io/badge/Version-0.4.16-orange.svg)
12
+ ![Version](https://img.shields.io/badge/Version-0.4.17-orange.svg)
13
13
 
14
14
 
15
15
  中文文档 | [English Version](README_en.md)
@@ -185,11 +185,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
185
185
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
186
186
  ```
187
187
 
188
- > 截止当前版本0.4.16,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
188
+ > 截止当前版本0.4.17,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
189
189
 
190
190
  ## 兼容平台
191
191
 
192
- 当前最新版本为0.4.16,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
192
+ 当前最新版本为0.4.17,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
193
193
 
194
194
  | 平台 | 配置 |
195
195
  |------|------|
@@ -9,7 +9,7 @@
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
 
11
11
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
12
- ![Version](https://img.shields.io/badge/Version-0.4.16-orange.svg)
12
+ ![Version](https://img.shields.io/badge/Version-0.4.17-orange.svg)
13
13
 
14
14
  English | [中文文档](README.md)
15
15
 
@@ -188,11 +188,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
188
188
  nextrec --mode=predict --predict_config=path/to/predict_config.yaml
189
189
  ```
190
190
 
191
- > As of version 0.4.16, NextRec CLI supports single-machine training; distributed training features are currently under development.
191
+ > As of version 0.4.17, NextRec CLI supports single-machine training; distributed training features are currently under development.
192
192
 
193
193
  ## Platform Compatibility
194
194
 
195
- The current version is 0.4.16. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
195
+ The current version is 0.4.17. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
196
196
 
197
197
  | Platform | Configuration |
198
198
  |----------|---------------|
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
11
11
  project = "NextRec"
12
12
  copyright = "2025, Yang Zhou"
13
13
  author = "Yang Zhou"
14
- release = "0.4.16"
14
+ release = "0.4.17"
15
15
 
16
16
  extensions = [
17
17
  "myst_parser",
@@ -0,0 +1 @@
1
+ __version__ = "0.4.17"
@@ -0,0 +1,101 @@
1
+ """
2
+ Task head implementations for NextRec models.
3
+
4
+ Date: create on 23/12/2025
5
+ Author: Yang Zhou, zyaztec@gmail.com
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Literal
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ from nextrec.basic.layers import PredictionLayer
17
+
18
+
19
+ class TaskHead(nn.Module):
20
+ """
21
+ Unified task head for ranking/regression/multi-task outputs.
22
+
23
+ This wraps PredictionLayer so models can depend on a "Head" abstraction
24
+ without changing their existing forward signatures.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ task_type: str | list[str] = "binary",
30
+ task_dims: int | list[int] | None = None,
31
+ use_bias: bool = True,
32
+ return_logits: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+ self.prediction = PredictionLayer(
36
+ task_type=task_type,
37
+ task_dims=task_dims,
38
+ use_bias=use_bias,
39
+ return_logits=return_logits,
40
+ )
41
+ # Expose commonly used attributes for compatibility with PredictionLayer.
42
+ self.task_types = self.prediction.task_types
43
+ self.task_dims = self.prediction.task_dims
44
+ self.task_slices = self.prediction.task_slices
45
+ self.total_dim = self.prediction.total_dim
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ return self.prediction(x)
49
+
50
+
51
+ class RetrievalHead(nn.Module):
52
+ """
53
+ Retrieval head for two-tower models.
54
+
55
+ It computes similarity for pointwise training/inference, and returns
56
+ raw embeddings for in-batch negative sampling in pairwise/listwise modes.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
62
+ temperature: float = 1.0,
63
+ training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
64
+ apply_sigmoid: bool = True,
65
+ ) -> None:
66
+ super().__init__()
67
+ self.similarity_metric = similarity_metric
68
+ self.temperature = temperature
69
+ self.training_mode = training_mode
70
+ self.apply_sigmoid = apply_sigmoid
71
+
72
+ def forward(
73
+ self,
74
+ user_emb: torch.Tensor,
75
+ item_emb: torch.Tensor,
76
+ similarity_fn=None,
77
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
78
+ if self.training and self.training_mode in {"pairwise", "listwise"}:
79
+ return user_emb, item_emb
80
+
81
+ if similarity_fn is not None:
82
+ similarity = similarity_fn(user_emb, item_emb)
83
+ else:
84
+ if user_emb.dim() == 2 and item_emb.dim() == 3:
85
+ user_emb = user_emb.unsqueeze(1)
86
+
87
+ if self.similarity_metric == "dot":
88
+ similarity = torch.sum(user_emb * item_emb, dim=-1)
89
+ elif self.similarity_metric == "cosine":
90
+ similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
91
+ elif self.similarity_metric == "euclidean":
92
+ similarity = -torch.sum((user_emb - item_emb) ** 2, dim=-1)
93
+ else:
94
+ raise ValueError(
95
+ f"Unknown similarity metric: {self.similarity_metric}"
96
+ )
97
+
98
+ similarity = similarity / self.temperature
99
+ if self.training_mode == "pointwise" and self.apply_sigmoid:
100
+ return torch.sigmoid(similarity)
101
+ return similarity
@@ -38,6 +38,7 @@ from nextrec.basic.features import (
38
38
  SequenceFeature,
39
39
  SparseFeature,
40
40
  )
41
+ from nextrec.basic.heads import RetrievalHead
41
42
  from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
42
43
  from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
43
44
  from nextrec.basic.session import create_session, resolve_save_path
@@ -2115,6 +2116,12 @@ class BaseMatchModel(BaseModel):
2115
2116
  )
2116
2117
  self.user_feature_names = {feature.name for feature in self.user_features_all}
2117
2118
  self.item_feature_names = {feature.name for feature in self.item_features_all}
2119
+ self.head = RetrievalHead(
2120
+ similarity_metric=self.similarity_metric,
2121
+ temperature=self.temperature,
2122
+ training_mode=self.training_mode,
2123
+ apply_sigmoid=True,
2124
+ )
2118
2125
 
2119
2126
  def compile(
2120
2127
  self,
@@ -2244,15 +2251,9 @@ class BaseMatchModel(BaseModel):
2244
2251
  user_emb = self.user_tower(user_input) # [B, D]
2245
2252
  item_emb = self.item_tower(item_input) # [B, D]
2246
2253
 
2247
- if self.training and self.training_mode in ["pairwise", "listwise"]:
2248
- return user_emb, item_emb
2249
-
2250
- similarity = self.compute_similarity(user_emb, item_emb) # [B]
2251
-
2252
- if self.training_mode == "pointwise":
2253
- return torch.sigmoid(similarity)
2254
- else:
2255
- return similarity
2254
+ return self.head(
2255
+ user_emb, item_emb, similarity_fn=self.compute_similarity
2256
+ )
2256
2257
 
2257
2258
  def compute_loss(self, y_pred, y_true):
2258
2259
  if self.training_mode == "pointwise":
@@ -45,7 +45,8 @@ import torch
45
45
  import torch.nn as nn
46
46
 
47
47
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
48
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
48
+ from nextrec.basic.layers import MLP, EmbeddingLayer
49
+ from nextrec.basic.heads import TaskHead
49
50
  from nextrec.basic.model import BaseModel
50
51
 
51
52
 
@@ -139,7 +140,7 @@ class ESMM(BaseModel):
139
140
  # CVR tower
140
141
  self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
141
142
  self.grad_norm_shared_modules = ["embedding"]
142
- self.prediction_layer = PredictionLayer(
143
+ self.prediction_layer = TaskHead(
143
144
  task_type=self.default_task, task_dims=[1, 1]
144
145
  )
145
146
  # Register regularization weights
@@ -167,4 +168,4 @@ class ESMM(BaseModel):
167
168
 
168
169
  # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
169
170
  y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
170
- return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
171
+ return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
@@ -46,7 +46,8 @@ import torch
46
46
  import torch.nn as nn
47
47
 
48
48
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
49
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
49
+ from nextrec.basic.layers import MLP, EmbeddingLayer
50
+ from nextrec.basic.heads import TaskHead
50
51
  from nextrec.basic.model import BaseModel
51
52
 
52
53
 
@@ -172,7 +173,7 @@ class MMOE(BaseModel):
172
173
  for tower_params in tower_params_list:
173
174
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
174
175
  self.towers.append(tower)
175
- self.prediction_layer = PredictionLayer(
176
+ self.prediction_layer = TaskHead(
176
177
  task_type=self.default_task, task_dims=[1] * self.num_tasks
177
178
  )
178
179
  # Register regularization weights
@@ -219,4 +220,4 @@ class MMOE(BaseModel):
219
220
 
220
221
  # Stack outputs: [B, num_tasks]
221
222
  y = torch.cat(task_outputs, dim=1)
222
- return self.prediction_layer(y)
223
+ return self.prediction_layer(y)
@@ -49,7 +49,8 @@ import torch
49
49
  import torch.nn as nn
50
50
 
51
51
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
52
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
52
+ from nextrec.basic.layers import MLP, EmbeddingLayer
53
+ from nextrec.basic.heads import TaskHead
53
54
  from nextrec.basic.model import BaseModel
54
55
  from nextrec.utils.model import get_mlp_output_dim
55
56
 
@@ -302,7 +303,7 @@ class PLE(BaseModel):
302
303
  for tower_params in tower_params_list:
303
304
  tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
304
305
  self.towers.append(tower)
305
- self.prediction_layer = PredictionLayer(
306
+ self.prediction_layer = TaskHead(
306
307
  task_type=self.default_task, task_dims=[1] * self.num_tasks
307
308
  )
308
309
  # Register regularization weights
@@ -336,4 +337,4 @@ class PLE(BaseModel):
336
337
 
337
338
  # [B, num_tasks]
338
339
  y = torch.cat(task_outputs, dim=1)
339
- return self.prediction_layer(y)
340
+ return self.prediction_layer(y)
@@ -44,7 +44,8 @@ import torch.nn.functional as F
44
44
 
45
45
  from nextrec.basic.activation import activation_layer
46
46
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
47
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
47
+ from nextrec.basic.layers import MLP, EmbeddingLayer
48
+ from nextrec.basic.heads import TaskHead
48
49
  from nextrec.basic.model import BaseModel
49
50
  from nextrec.utils.model import select_features
50
51
 
@@ -487,7 +488,7 @@ class POSO(BaseModel):
487
488
  self.grad_norm_shared_modules = ["embedding"]
488
489
  else:
489
490
  self.grad_norm_shared_modules = ["embedding", "mmoe"]
490
- self.prediction_layer = PredictionLayer(
491
+ self.prediction_layer = TaskHead(
491
492
  task_type=self.default_task,
492
493
  task_dims=[1] * self.num_tasks,
493
494
  )
@@ -524,4 +525,4 @@ class POSO(BaseModel):
524
525
  task_outputs.append(logit)
525
526
 
526
527
  y = torch.cat(task_outputs, dim=1)
527
- return self.prediction_layer(y)
528
+ return self.prediction_layer(y)
@@ -43,7 +43,8 @@ import torch
43
43
  import torch.nn as nn
44
44
 
45
45
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
46
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
46
+ from nextrec.basic.layers import MLP, EmbeddingLayer
47
+ from nextrec.basic.heads import TaskHead
47
48
  from nextrec.basic.model import BaseModel
48
49
 
49
50
 
@@ -142,7 +143,7 @@ class ShareBottom(BaseModel):
142
143
  for tower_params in tower_params_list:
143
144
  tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
144
145
  self.towers.append(tower)
145
- self.prediction_layer = PredictionLayer(
146
+ self.prediction_layer = TaskHead(
146
147
  task_type=self.default_task, task_dims=[1] * self.num_tasks
147
148
  )
148
149
  # Register regularization weights
@@ -171,4 +172,4 @@ class ShareBottom(BaseModel):
171
172
 
172
173
  # Stack outputs: [B, num_tasks]
173
174
  y = torch.cat(task_outputs, dim=1)
174
- return self.prediction_layer(y)
175
+ return self.prediction_layer(y)
@@ -40,7 +40,8 @@ import torch
40
40
  import torch.nn as nn
41
41
 
42
42
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
43
- from nextrec.basic.layers import EmbeddingLayer, InputMask, PredictionLayer
43
+ from nextrec.basic.layers import EmbeddingLayer, InputMask
44
+ from nextrec.basic.heads import TaskHead
44
45
  from nextrec.basic.model import BaseModel
45
46
 
46
47
 
@@ -141,7 +142,7 @@ class AFM(BaseModel):
141
142
  self.attention_p = nn.Linear(attention_dim, 1, bias=False)
142
143
  self.attention_dropout = nn.Dropout(attention_dropout)
143
144
  self.output_projection = nn.Linear(self.embedding_dim, 1, bias=False)
144
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
145
+ self.prediction_layer = TaskHead(task_type=self.default_task)
145
146
  self.input_mask = InputMask()
146
147
 
147
148
  # Register regularization weights
@@ -243,4 +244,4 @@ class AFM(BaseModel):
243
244
  y_afm = self.output_projection(weighted_sum)
244
245
 
245
246
  y = y_linear + y_afm
246
- return self.prediction_layer(y)
247
+ return self.prediction_layer(y)
@@ -58,7 +58,8 @@ import torch
58
58
  import torch.nn as nn
59
59
 
60
60
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
61
- from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention, PredictionLayer
61
+ from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention
62
+ from nextrec.basic.heads import TaskHead
62
63
  from nextrec.basic.model import BaseModel
63
64
 
64
65
 
@@ -162,7 +163,7 @@ class AutoInt(BaseModel):
162
163
 
163
164
  # Final prediction layer
164
165
  self.fc = nn.Linear(num_fields * att_embedding_dim, 1)
165
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
166
+ self.prediction_layer = TaskHead(task_type=self.default_task)
166
167
 
167
168
  # Register regularization weights
168
169
  self.register_regularization_weights(
@@ -206,4 +207,4 @@ class AutoInt(BaseModel):
206
207
  start_dim=1
207
208
  ) # [B, num_fields * att_embedding_dim]
208
209
  y = self.fc(attention_output_flat) # [B, 1]
209
- return self.prediction_layer(y)
210
+ return self.prediction_layer(y)
@@ -54,7 +54,8 @@ import torch
54
54
  import torch.nn as nn
55
55
 
56
56
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
57
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
57
+ from nextrec.basic.layers import MLP, EmbeddingLayer
58
+ from nextrec.basic.heads import TaskHead
58
59
  from nextrec.basic.model import BaseModel
59
60
 
60
61
 
@@ -163,7 +164,7 @@ class DCN(BaseModel):
163
164
  # Final layer only uses cross network output
164
165
  self.final_layer = nn.Linear(input_dim, 1)
165
166
 
166
- self.prediction_layer = PredictionLayer(task_type=self.task)
167
+ self.prediction_layer = TaskHead(task_type=self.task)
167
168
 
168
169
  # Register regularization weights
169
170
  self.register_regularization_weights(
@@ -197,4 +198,4 @@ class DCN(BaseModel):
197
198
 
198
199
  # Final prediction
199
200
  y = self.final_layer(combined)
200
- return self.prediction_layer(y)
201
+ return self.prediction_layer(y)
@@ -47,7 +47,8 @@ import torch
47
47
  import torch.nn as nn
48
48
 
49
49
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
50
- from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
50
+ from nextrec.basic.layers import MLP, EmbeddingLayer
51
+ from nextrec.basic.heads import TaskHead
51
52
  from nextrec.basic.model import BaseModel
52
53
 
53
54
 
@@ -272,7 +273,7 @@ class DCNv2(BaseModel):
272
273
  final_input_dim = input_dim
273
274
 
274
275
  self.final_layer = nn.Linear(final_input_dim, 1)
275
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
276
+ self.prediction_layer = TaskHead(task_type=self.default_task)
276
277
 
277
278
  self.register_regularization_weights(
278
279
  embedding_attr="embedding",
@@ -301,4 +302,4 @@ class DCNv2(BaseModel):
301
302
  combined = cross_out
302
303
 
303
304
  logit = self.final_layer(combined)
304
- return self.prediction_layer(logit)
305
+ return self.prediction_layer(logit)
@@ -45,7 +45,8 @@ embedding,无需手工构造交叉特征即可端到端训练,常用于 CTR/
45
45
  import torch.nn as nn
46
46
 
47
47
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
48
- from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer, PredictionLayer
48
+ from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer
49
+ from nextrec.basic.heads import TaskHead
49
50
  from nextrec.basic.model import BaseModel
50
51
 
51
52
 
@@ -111,7 +112,7 @@ class DeepFM(BaseModel):
111
112
  self.linear = LR(fm_emb_dim_total)
112
113
  self.fm = FM(reduce_sum=True)
113
114
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
114
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
115
+ self.prediction_layer = TaskHead(task_type=self.default_task)
115
116
 
116
117
  # Register regularization weights
117
118
  self.register_regularization_weights(
@@ -133,4 +134,4 @@ class DeepFM(BaseModel):
133
134
  y_deep = self.mlp(input_deep) # [B, 1]
134
135
 
135
136
  y = y_linear + y_fm + y_deep
136
- return self.prediction_layer(y)
137
+ return self.prediction_layer(y)
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
55
55
  MLP,
56
56
  AttentionPoolingLayer,
57
57
  EmbeddingLayer,
58
- PredictionLayer,
59
58
  )
59
+ from nextrec.basic.heads import TaskHead
60
60
  from nextrec.basic.model import BaseModel
61
61
 
62
62
 
@@ -346,7 +346,7 @@ class DIEN(BaseModel):
346
346
  )
347
347
 
348
348
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
349
- self.prediction_layer = PredictionLayer(task_type=self.task)
349
+ self.prediction_layer = TaskHead(task_type=self.task)
350
350
 
351
351
  self.register_regularization_weights(
352
352
  embedding_attr="embedding",
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
55
55
  MLP,
56
56
  AttentionPoolingLayer,
57
57
  EmbeddingLayer,
58
- PredictionLayer,
59
58
  )
59
+ from nextrec.basic.heads import TaskHead
60
60
  from nextrec.basic.model import BaseModel
61
61
 
62
62
 
@@ -173,7 +173,7 @@ class DIN(BaseModel):
173
173
 
174
174
  # MLP for final prediction
175
175
  self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
176
- self.prediction_layer = PredictionLayer(task_type=self.task)
176
+ self.prediction_layer = TaskHead(task_type=self.task)
177
177
 
178
178
  # Register regularization weights
179
179
  self.register_regularization_weights(
@@ -38,7 +38,8 @@ import torch.nn as nn
38
38
  import torch.nn.functional as F
39
39
 
40
40
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
41
- from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
41
+ from nextrec.basic.layers import LR, EmbeddingLayer
42
+ from nextrec.basic.heads import TaskHead
42
43
  from nextrec.basic.model import BaseModel
43
44
 
44
45
 
@@ -295,7 +296,7 @@ class EulerNet(BaseModel):
295
296
  else:
296
297
  self.linear = None
297
298
 
298
- self.prediction_layer = PredictionLayer(task_type=self.task)
299
+ self.prediction_layer = TaskHead(task_type=self.task)
299
300
 
300
301
  modules = ["mapping", "layers", "w", "w_im"]
301
302
  if self.use_linear:
@@ -331,4 +332,4 @@ class EulerNet(BaseModel):
331
332
  r, p = layer(r, p)
332
333
  r_flat = r.reshape(r.size(0), self.num_orders * self.embedding_dim)
333
334
  p_flat = p.reshape(p.size(0), self.num_orders * self.embedding_dim)
334
- return self.w(r_flat) + self.w_im(p_flat)
335
+ return self.w(r_flat) + self.w_im(p_flat)
@@ -43,7 +43,8 @@ import torch
43
43
  import torch.nn as nn
44
44
 
45
45
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
46
- from nextrec.basic.layers import AveragePooling, InputMask, PredictionLayer, SumPooling
46
+ from nextrec.basic.layers import AveragePooling, InputMask, SumPooling
47
+ from nextrec.basic.heads import TaskHead
47
48
  from nextrec.basic.model import BaseModel
48
49
  from nextrec.utils.torch_utils import get_initializer
49
50
 
@@ -140,7 +141,7 @@ class FFM(BaseModel):
140
141
  nn.Linear(dense_input_dim, 1, bias=True) if dense_input_dim > 0 else None
141
142
  )
142
143
 
143
- self.prediction_layer = PredictionLayer(task_type=self.task)
144
+ self.prediction_layer = TaskHead(task_type=self.task)
144
145
  self.input_mask = InputMask()
145
146
  self.mean_pool = AveragePooling()
146
147
  self.sum_pool = SumPooling()
@@ -272,4 +273,4 @@ class FFM(BaseModel):
272
273
  )
273
274
 
274
275
  y = y_linear + y_interaction
275
- return self.prediction_layer(y)
276
+ return self.prediction_layer(y)
@@ -50,9 +50,9 @@ from nextrec.basic.layers import (
50
50
  BiLinearInteractionLayer,
51
51
  EmbeddingLayer,
52
52
  HadamardInteractionLayer,
53
- PredictionLayer,
54
53
  SENETLayer,
55
54
  )
55
+ from nextrec.basic.heads import TaskHead
56
56
  from nextrec.basic.model import BaseModel
57
57
 
58
58
 
@@ -168,7 +168,7 @@ class FiBiNET(BaseModel):
168
168
  num_pairs = self.num_fields * (self.num_fields - 1) // 2
169
169
  interaction_dim = num_pairs * self.embedding_dim * 2
170
170
  self.mlp = MLP(input_dim=interaction_dim, **mlp_params)
171
- self.prediction_layer = PredictionLayer(task_type=self.default_task)
171
+ self.prediction_layer = TaskHead(task_type=self.default_task)
172
172
 
173
173
  # Register regularization weights
174
174
  self.register_regularization_weights(
@@ -42,7 +42,8 @@ import torch.nn as nn
42
42
 
43
43
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
44
44
  from nextrec.basic.layers import FM as FMInteraction
45
- from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
45
+ from nextrec.basic.heads import TaskHead
46
+ from nextrec.basic.layers import LR, EmbeddingLayer
46
47
  from nextrec.basic.model import BaseModel
47
48
 
48
49
 
@@ -105,7 +106,7 @@ class FM(BaseModel):
105
106
  fm_input_dim = sum([f.embedding_dim for f in self.fm_features])
106
107
  self.linear = LR(fm_input_dim)
107
108
  self.fm = FMInteraction(reduce_sum=True)
108
- self.prediction_layer = PredictionLayer(task_type=self.task)
109
+ self.prediction_layer = TaskHead(task_type=self.task)
109
110
 
110
111
  # Register regularization weights
111
112
  self.register_regularization_weights(
@@ -124,4 +125,4 @@ class FM(BaseModel):
124
125
  y_linear = self.linear(input_fm.flatten(start_dim=1))
125
126
  y_fm = self.fm(input_fm)
126
127
  y = y_linear + y_fm
127
- return self.prediction_layer(y)
128
+ return self.prediction_layer(y)
@@ -41,7 +41,8 @@ LR 是 CTR/排序任务中最经典的线性基线模型。它将稠密、稀疏
41
41
  import torch.nn as nn
42
42
 
43
43
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
44
- from nextrec.basic.layers import EmbeddingLayer, LR as LinearLayer, PredictionLayer
44
+ from nextrec.basic.layers import EmbeddingLayer, LR as LinearLayer
45
+ from nextrec.basic.heads import TaskHead
45
46
  from nextrec.basic.model import BaseModel
46
47
 
47
48
 
@@ -99,7 +100,7 @@ class LR(BaseModel):
99
100
  self.embedding = EmbeddingLayer(features=self.all_features)
100
101
  linear_input_dim = self.embedding.input_dim
101
102
  self.linear = LinearLayer(linear_input_dim)
102
- self.prediction_layer = PredictionLayer(task_type=self.task)
103
+ self.prediction_layer = TaskHead(task_type=self.task)
103
104
 
104
105
  self.register_regularization_weights(
105
106
  embedding_attr="embedding", include_modules=["linear"]
@@ -115,4 +116,4 @@ class LR(BaseModel):
115
116
  def forward(self, x):
116
117
  input_linear = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
117
118
  y = self.linear(input_linear)
118
- return self.prediction_layer(y)
119
+ return self.prediction_layer(y)