nextrec 0.4.22__tar.gz → 0.4.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174) hide show
  1. {nextrec-0.4.22 → nextrec-0.4.23}/PKG-INFO +7 -5
  2. {nextrec-0.4.22 → nextrec-0.4.23}/README.md +6 -4
  3. {nextrec-0.4.22 → nextrec-0.4.23}/README_en.md +4 -4
  4. {nextrec-0.4.22 → nextrec-0.4.23}/docs/en/Getting started guide.md +1 -0
  5. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/conf.py +1 -1
  6. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/index.md +1 -0
  7. {nextrec-0.4.22 → nextrec-0.4.23}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md +1 -0
  8. nextrec-0.4.23/nextrec/__version__.py +1 -0
  9. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/metrics.py +1 -2
  10. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/model.py +68 -73
  11. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/summary.py +36 -2
  12. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/preprocessor.py +137 -5
  13. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/listwise.py +19 -6
  14. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/pairwise.py +6 -4
  15. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/pointwise.py +8 -6
  16. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/esmm.py +3 -26
  17. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/mmoe.py +2 -24
  18. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/ple.py +13 -35
  19. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/poso.py +4 -28
  20. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/share_bottom.py +1 -24
  21. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/afm.py +3 -27
  22. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/autoint.py +5 -38
  23. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/dcn.py +1 -26
  24. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/dcn_v2.py +5 -33
  25. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/deepfm.py +2 -29
  26. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/dien.py +2 -28
  27. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/din.py +2 -27
  28. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/eulernet.py +3 -30
  29. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/ffm.py +0 -26
  30. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/fibinet.py +8 -32
  31. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/fm.py +0 -29
  32. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/lr.py +0 -30
  33. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/masknet.py +4 -30
  34. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/pnn.py +4 -28
  35. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/widedeep.py +0 -32
  36. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/xdeepfm.py +0 -30
  37. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/dssm.py +0 -24
  38. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/dssm_v2.py +0 -24
  39. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/mind.py +0 -20
  40. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/sdm.py +0 -20
  41. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/youtube_dnn.py +0 -21
  42. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/sequential/hstu.py +0 -18
  43. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/model.py +79 -1
  44. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/types.py +35 -0
  45. {nextrec-0.4.22 → nextrec-0.4.23}/pyproject.toml +1 -1
  46. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_ranking_models.py +0 -3
  47. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/example_multitask.py +1 -8
  48. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/example_ranking_din.py +3 -5
  49. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on nextrec.ipynb +1 -1
  50. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/zh//345/277/253/351/200/237/345/205/245/351/227/250nextrec.ipynb +1 -1
  51. nextrec-0.4.22/nextrec/__version__.py +0 -1
  52. {nextrec-0.4.22 → nextrec-0.4.23}/.github/workflows/publish.yml +0 -0
  53. {nextrec-0.4.22 → nextrec-0.4.23}/.github/workflows/tests.yml +0 -0
  54. {nextrec-0.4.22 → nextrec-0.4.23}/.gitignore +0 -0
  55. {nextrec-0.4.22 → nextrec-0.4.23}/.readthedocs.yaml +0 -0
  56. {nextrec-0.4.22 → nextrec-0.4.23}/CODE_OF_CONDUCT.md +0 -0
  57. {nextrec-0.4.22 → nextrec-0.4.23}/CONTRIBUTING.md +0 -0
  58. {nextrec-0.4.22 → nextrec-0.4.23}/LICENSE +0 -0
  59. {nextrec-0.4.22 → nextrec-0.4.23}/MANIFEST.in +0 -0
  60. {nextrec-0.4.22 → nextrec-0.4.23}/assets/Feature Configuration.png +0 -0
  61. {nextrec-0.4.22 → nextrec-0.4.23}/assets/Model Parameters.png +0 -0
  62. {nextrec-0.4.22 → nextrec-0.4.23}/assets/Training Configuration.png +0 -0
  63. {nextrec-0.4.22 → nextrec-0.4.23}/assets/Training logs.png +0 -0
  64. {nextrec-0.4.22 → nextrec-0.4.23}/assets/logo.png +0 -0
  65. {nextrec-0.4.22 → nextrec-0.4.23}/assets/mmoe_tutorial.png +0 -0
  66. {nextrec-0.4.22 → nextrec-0.4.23}/assets/nextrec_diagram.png +0 -0
  67. {nextrec-0.4.22 → nextrec-0.4.23}/assets/test data.png +0 -0
  68. {nextrec-0.4.22 → nextrec-0.4.23}/dataset/ctcvr_task.csv +0 -0
  69. {nextrec-0.4.22 → nextrec-0.4.23}/dataset/ecommerce_task.csv +0 -0
  70. {nextrec-0.4.22 → nextrec-0.4.23}/dataset/match_task.csv +0 -0
  71. {nextrec-0.4.22 → nextrec-0.4.23}/dataset/movielens_100k.csv +0 -0
  72. {nextrec-0.4.22 → nextrec-0.4.23}/dataset/multitask_task.csv +0 -0
  73. {nextrec-0.4.22 → nextrec-0.4.23}/dataset/ranking_task.csv +0 -0
  74. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/Makefile +0 -0
  75. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/make.bat +0 -0
  76. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/modules.rst +0 -0
  77. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.basic.rst +0 -0
  78. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.data.rst +0 -0
  79. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.loss.rst +0 -0
  80. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.rst +0 -0
  81. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.utils.rst +0 -0
  82. {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/requirements.txt +0 -0
  83. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/__init__.py +0 -0
  84. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/__init__.py +0 -0
  85. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/activation.py +0 -0
  86. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/callback.py +0 -0
  87. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/features.py +0 -0
  88. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/heads.py +0 -0
  89. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/layers.py +0 -0
  90. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/loggers.py +0 -0
  91. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/session.py +0 -0
  92. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/cli.py +0 -0
  93. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/__init__.py +0 -0
  94. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/batch_utils.py +0 -0
  95. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/data_processing.py +0 -0
  96. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/data_utils.py +0 -0
  97. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/dataloader.py +0 -0
  98. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/__init__.py +0 -0
  99. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/grad_norm.py +0 -0
  100. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/generative/__init__.py +0 -0
  101. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/generative/tiger.py +0 -0
  102. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/__init__.py +0 -0
  103. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/__init__.py +0 -0
  104. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/__init__.py +0 -0
  105. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/autorec.py +0 -0
  106. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/bpr.py +0 -0
  107. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/cl4srec.py +0 -0
  108. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/lightgcn.py +0 -0
  109. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/mf.py +0 -0
  110. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/rqvae.py +0 -0
  111. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/s3rec.py +0 -0
  112. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/__init__.py +0 -0
  113. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/sequential/sasrec.py +0 -0
  114. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/__init__.py +0 -0
  115. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/config.py +0 -0
  116. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/console.py +0 -0
  117. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/data.py +0 -0
  118. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/embedding.py +0 -0
  119. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/feature.py +0 -0
  120. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/loss.py +0 -0
  121. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/torch_utils.py +0 -0
  122. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI.md +0 -0
  123. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
  124. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/feature_config.yaml +0 -0
  125. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  126. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  127. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  128. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  129. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  130. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  131. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  132. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  133. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  134. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  135. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  136. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  137. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  138. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  139. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  140. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  141. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/predict_config.yaml +0 -0
  142. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/predict_config_template.yaml +0 -0
  143. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/train_config.yaml +0 -0
  144. {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/train_config_template.yaml +0 -0
  145. {nextrec-0.4.22 → nextrec-0.4.23}/pytest.ini +0 -0
  146. {nextrec-0.4.22 → nextrec-0.4.23}/requirements.txt +0 -0
  147. {nextrec-0.4.22 → nextrec-0.4.23}/scripts/format_code.py +0 -0
  148. {nextrec-0.4.22 → nextrec-0.4.23}/test/__init__.py +0 -0
  149. {nextrec-0.4.22 → nextrec-0.4.23}/test/conftest.py +0 -0
  150. {nextrec-0.4.22 → nextrec-0.4.23}/test/helpers.py +0 -0
  151. {nextrec-0.4.22 → nextrec-0.4.23}/test/run_tests.py +0 -0
  152. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_base_model_regularization.py +0 -0
  153. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_generative_models.py +0 -0
  154. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_layers.py +0 -0
  155. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_losses.py +0 -0
  156. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_match_models.py +0 -0
  157. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_multitask_models.py +0 -0
  158. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_preprocessor.py +0 -0
  159. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_utils_console.py +0 -0
  160. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_utils_data.py +0 -0
  161. {nextrec-0.4.22 → nextrec-0.4.23}/test/test_utils_embedding.py +0 -0
  162. {nextrec-0.4.22 → nextrec-0.4.23}/test_requirements.txt +0 -0
  163. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training.py +0 -0
  164. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  165. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/example_match.py +0 -0
  166. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/movielen_match_dssm.py +0 -0
  167. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/movielen_ranking_deepfm.py +0 -0
  168. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  169. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  170. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/zh//344/275/277/347/224/250RQ-VAE/346/236/204/345/273/272/350/257/255/344/271/211ID.ipynb" +0 -0
  171. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/zh//345/246/202/344/275/225/344/275/277/347/224/250DataProcessor/350/277/233/350/241/214/351/242/204/345/244/204/347/220/206.ipynb" +0 -0
  172. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/run_all_match_models.py +0 -0
  173. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/run_all_multitask_models.py +0 -0
  174. {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/run_all_ranking_models.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.22
3
+ Version: 0.4.23
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
69
69
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
70
70
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
71
71
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
72
- ![Version](https://img.shields.io/badge/Version-0.4.22-orange.svg)
72
+ ![Version](https://img.shields.io/badge/Version-0.4.23-orange.svg)
73
73
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
74
74
 
75
75
  中文文档 | [English Version](README_en.md)
@@ -191,6 +191,8 @@ model = DIN(
191
191
  dense_features=dense_features,
192
192
  sparse_features=sparse_features,
193
193
  sequence_features=sequence_features,
194
+ behavior_feature_name="sequence_0",
195
+ candidate_feature_name="item_id",
194
196
  mlp_params=mlp_params,
195
197
  attention_hidden_units=[80, 40],
196
198
  attention_activation='sigmoid',
@@ -204,7 +206,7 @@ model = DIN(
204
206
  session_id="din_tutorial", # 实验id,用于存放训练日志
205
207
  )
206
208
 
207
- # 编译模型,设置优化器和损失函数
209
+ # 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
208
210
  model.compile(
209
211
  optimizer = "adam",
210
212
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -247,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
247
249
 
248
250
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
249
251
 
250
- > 截止当前版本0.4.22,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
252
+ > 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
251
253
 
252
254
  ## 兼容平台
253
255
 
254
- 当前最新版本为0.4.22,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
256
+ 当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
255
257
 
256
258
  | 平台 | 配置 |
257
259
  |------|------|
@@ -8,7 +8,7 @@
8
8
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
11
- ![Version](https://img.shields.io/badge/Version-0.4.22-orange.svg)
11
+ ![Version](https://img.shields.io/badge/Version-0.4.23-orange.svg)
12
12
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
13
13
 
14
14
  中文文档 | [English Version](README_en.md)
@@ -130,6 +130,8 @@ model = DIN(
130
130
  dense_features=dense_features,
131
131
  sparse_features=sparse_features,
132
132
  sequence_features=sequence_features,
133
+ behavior_feature_name="sequence_0",
134
+ candidate_feature_name="item_id",
133
135
  mlp_params=mlp_params,
134
136
  attention_hidden_units=[80, 40],
135
137
  attention_activation='sigmoid',
@@ -143,7 +145,7 @@ model = DIN(
143
145
  session_id="din_tutorial", # 实验id,用于存放训练日志
144
146
  )
145
147
 
146
- # 编译模型,设置优化器和损失函数
148
+ # 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
147
149
  model.compile(
148
150
  optimizer = "adam",
149
151
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -186,11 +188,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
186
188
 
187
189
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
188
190
 
189
- > 截止当前版本0.4.22,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
191
+ > 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
190
192
 
191
193
  ## 兼容平台
192
194
 
193
- 当前最新版本为0.4.22,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
195
+ 当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
194
196
 
195
197
  | 平台 | 配置 |
196
198
  |------|------|
@@ -8,7 +8,7 @@
8
8
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
11
- ![Version](https://img.shields.io/badge/Version-0.4.22-orange.svg)
11
+ ![Version](https://img.shields.io/badge/Version-0.4.23-orange.svg)
12
12
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
13
13
 
14
14
  English | [中文文档](README.md)
@@ -148,7 +148,7 @@ model = DIN(
148
148
  session_id="din_tutorial", # experiment id for logs
149
149
  )
150
150
 
151
- # Compile model with optimizer and loss
151
+ # Compile model; configure optimizer/loss/scheduler via compile()
152
152
  model.compile(
153
153
  optimizer = "adam",
154
154
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -191,11 +191,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
191
191
 
192
192
  Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.
193
193
 
194
- > As of version 0.4.22, NextRec CLI supports single-machine training; distributed training features are currently under development.
194
+ > As of version 0.4.23, NextRec CLI supports single-machine training; distributed training features are currently under development.
195
195
 
196
196
  ## Platform Compatibility
197
197
 
198
- The current version is 0.4.22. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
198
+ The current version is 0.4.23. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
199
199
 
200
200
  | Platform | Configuration |
201
201
  |----------|---------------|
@@ -55,6 +55,7 @@ model = DeepFM(
55
55
  session_id="movielens_deepfm", # manages logs and checkpoints
56
56
  )
57
57
 
58
+ # Optimizer/loss/scheduler are configured via compile()
58
59
  model.compile(
59
60
  optimizer="adam",
60
61
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
11
11
  project = "NextRec"
12
12
  copyright = "2025, Yang Zhou"
13
13
  author = "Yang Zhou"
14
- release = "0.4.22"
14
+ release = "0.4.23"
15
15
 
16
16
  extensions = [
17
17
  "myst_parser",
@@ -57,6 +57,7 @@ model = DeepFM(
57
57
  session_id="deepfm_demo",
58
58
  )
59
59
 
60
+ # Configure optimizer/loss/scheduler via compile()
60
61
  model.compile(
61
62
  optimizer="adam",
62
63
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -55,6 +55,7 @@ model = DeepFM(
55
55
  session_id="movielens_deepfm", # 管理实验日志与检查点
56
56
  )
57
57
 
58
+ # 优化器/损失/学习率调度器统一在 compile 中设置
58
59
  model.compile(
59
60
  optimizer="adam",
60
61
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -0,0 +1 @@
1
+ __version__ = "0.4.23"
@@ -2,7 +2,7 @@
2
2
  Metrics computation and configuration for model evaluation.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 20/12/2025
5
+ Checkpoint: edit on 29/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -39,7 +39,6 @@ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
39
39
  TASK_DEFAULT_METRICS = {
40
40
  "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
41
41
  "regression": ["mse", "mae", "rmse", "r2", "mape"],
42
- "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
43
42
  "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
44
43
  + [f"recall@{k}" for k in (5, 10, 20)]
45
44
  + [f"ndcg@{k}" for k in (5, 10, 20)]
@@ -2,13 +2,14 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 28/12/2025
5
+ Checkpoint: edit on 29/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
9
9
  import getpass
10
10
  import logging
11
11
  import os
12
+ import sys
12
13
  import pickle
13
14
  import socket
14
15
  from pathlib import Path
@@ -16,6 +17,16 @@ from typing import Any, Literal
16
17
 
17
18
  import numpy as np
18
19
  import pandas as pd
20
+
21
+ try:
22
+ import swanlab # type: ignore
23
+ except ModuleNotFoundError:
24
+ swanlab = None
25
+ try:
26
+ import wandb # type: ignore
27
+ except ModuleNotFoundError:
28
+ wandb = None
29
+
19
30
  import torch
20
31
  import torch.distributed as dist
21
32
  import torch.nn as nn
@@ -74,13 +85,19 @@ from nextrec.utils.torch_utils import (
74
85
  to_tensor,
75
86
  )
76
87
  from nextrec.utils.config import safe_value
77
- from nextrec.utils.model import compute_ranking_loss
88
+ from nextrec.utils.model import (
89
+ compute_ranking_loss,
90
+ get_loss_list,
91
+ resolve_loss_weights,
92
+ get_training_modes,
93
+ )
78
94
  from nextrec.utils.types import (
79
95
  LossName,
80
96
  OptimizerName,
81
97
  SchedulerName,
82
98
  TrainingModeName,
83
99
  TaskTypeName,
100
+ MetricsName,
84
101
  )
85
102
 
86
103
 
@@ -90,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
90
107
  raise NotImplementedError
91
108
 
92
109
  @property
93
- def default_task(self) -> str | list[str]:
110
+ def default_task(self) -> TaskTypeName | list[TaskTypeName]:
94
111
  raise NotImplementedError
95
112
 
96
113
  def __init__(
@@ -139,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
139
156
  world_size: Number of processes (defaults to env WORLD_SIZE).
140
157
  local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
141
158
  ddp_find_unused_parameters: Defaults to False; set it to True only when unused parameters exist in the DDP model. In most cases this should be False.
159
+
160
+ Note:
161
+ Optimizer, scheduler, and loss are configured via compile().
142
162
  """
143
163
  super(BaseModel, self).__init__()
144
164
 
@@ -171,24 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
171
191
  dense_features, sparse_features, sequence_features, target, id_columns
172
192
  )
173
193
 
174
- self.task = self.default_task if task is None else task
194
+ self.task = task or self.default_task
175
195
  self.nums_task = len(self.task) if isinstance(self.task, list) else 1
176
- if isinstance(training_mode, list):
177
- training_modes = list(training_mode)
178
- if len(training_modes) != self.nums_task:
179
- raise ValueError(
180
- "[BaseModel-init Error] training_mode list length must match number of tasks."
181
- )
182
- else:
183
- training_modes = [training_mode] * self.nums_task
184
- if any(
185
- mode not in {"pointwise", "pairwise", "listwise"} for mode in training_modes
186
- ):
187
- raise ValueError(
188
- "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
189
- )
190
- self.training_modes = training_modes
191
- self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
196
+ self.training_modes = get_training_modes(training_mode, self.nums_task)
197
+ self.training_mode = (
198
+ self.training_modes if self.nums_task > 1 else self.training_modes[0]
199
+ )
192
200
 
193
201
  self.embedding_l1_reg = embedding_l1_reg
194
202
  self.dense_l1_reg = dense_l1_reg
@@ -196,8 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
196
204
  self.dense_l2_reg = dense_l2_reg
197
205
  self.regularization_weights = []
198
206
  self.embedding_params = []
199
- self.loss_weight = None
207
+
200
208
  self.ignore_label = None
209
+ self.compiled = False
201
210
 
202
211
  self.max_gradient_norm = 1.0
203
212
  self.logger_initialized = False
@@ -431,28 +440,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
431
440
  "pairwise": "bpr",
432
441
  "listwise": "listnet",
433
442
  }
434
- effective_loss = loss
435
- if effective_loss is None:
436
- loss_list = [default_losses[mode] for mode in self.training_modes]
437
- elif isinstance(effective_loss, list):
438
- if not effective_loss:
439
- loss_list = [default_losses[mode] for mode in self.training_modes]
440
- else:
441
- if len(effective_loss) != self.nums_task:
442
- raise ValueError(
443
- f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
444
- )
445
- loss_list = list(effective_loss)
446
- else:
447
- loss_list = [effective_loss] * self.nums_task
448
-
449
- for idx, mode in enumerate(self.training_modes):
450
- if isinstance(loss_list[idx], str) and loss_list[idx] in {
451
- "bce",
452
- "binary_crossentropy",
453
- }:
454
- if mode in {"pairwise", "listwise"}:
455
- loss_list[idx] = default_losses[mode]
443
+ loss_list = get_loss_list(
444
+ loss, self.training_modes, self.nums_task, default_losses
445
+ )
456
446
  self.loss_params = loss_params or {}
457
447
  optimizer_params = optimizer_params or {}
458
448
  self.optimizer_name = (
@@ -516,30 +506,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
516
506
  nums_task=self.nums_task, device=self.device, **grad_norm_params
517
507
  )
518
508
  self.loss_weights = None
519
- elif loss_weights is None:
520
- self.loss_weights = None
521
- elif self.nums_task == 1:
522
- if isinstance(loss_weights, (list, tuple)):
523
- if len(loss_weights) != 1:
524
- raise ValueError(
525
- "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
526
- )
527
- loss_weights = loss_weights[0]
528
- self.loss_weights = [float(loss_weights)] # type: ignore
529
509
  else:
530
- if isinstance(loss_weights, (int, float)):
531
- weights = [float(loss_weights)] * self.nums_task
532
- elif isinstance(loss_weights, (list, tuple)):
533
- weights = [float(w) for w in loss_weights]
534
- if len(weights) != self.nums_task:
535
- raise ValueError(
536
- f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
537
- )
538
- else:
539
- raise TypeError(
540
- f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
541
- )
542
- self.loss_weights = weights
510
+ self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
511
+ self.compiled = True
543
512
 
544
513
  def compute_loss(self, y_pred, y_true):
545
514
  if y_true is None:
@@ -602,9 +571,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
602
571
  for i, (start, end) in enumerate(slices): # type: ignore
603
572
  y_pred_i = y_pred[:, start:end]
604
573
  y_true_i = y_true[:, start:end]
605
- total_count = y_true_i.shape[0]
606
- # valid_count = None
607
-
608
574
  # mask ignored labels
609
575
  if self.ignore_label is not None:
610
576
  valid_mask = y_true_i != self.ignore_label
@@ -613,11 +579,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
613
579
  if not torch.any(valid_mask):
614
580
  task_losses.append(y_pred_i.sum() * 0.0)
615
581
  continue
616
- # valid_count = valid_mask.sum().to(dtype=y_true_i.dtype)
617
582
  y_pred_i = y_pred_i[valid_mask]
618
583
  y_true_i = y_true_i[valid_mask]
619
- # else:
620
- # valid_count = y_true_i.new_tensor(float(total_count))
621
584
 
622
585
  mode = self.training_modes[i]
623
586
 
@@ -691,7 +654,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
691
654
  train_data=None,
692
655
  valid_data=None,
693
656
  metrics: (
694
- list[str] | dict[str, list[str]] | None
657
+ list[MetricsName] | dict[str, list[MetricsName]] | None
695
658
  ) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
696
659
  epochs: int = 1,
697
660
  shuffle: bool = True,
@@ -705,6 +668,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
705
668
  use_tensorboard: bool = True,
706
669
  use_wandb: bool = False,
707
670
  use_swanlab: bool = False,
671
+ wandb_api: str | None = None,
672
+ swanlab_api: str | None = None,
708
673
  wandb_kwargs: dict | None = None,
709
674
  swanlab_kwargs: dict | None = None,
710
675
  auto_ddp_sampler: bool = True,
@@ -734,6 +699,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
734
699
  use_tensorboard: Enable tensorboard logging.
735
700
  use_wandb: Enable Weights & Biases logging.
736
701
  use_swanlab: Enable SwanLab logging.
702
+ wandb_api: W&B API key for non-tty login.
703
+ swanlab_api: SwanLab API key for non-tty login.
737
704
  wandb_kwargs: Optional kwargs for wandb.init(...).
738
705
  swanlab_kwargs: Optional kwargs for swanlab.init(...).
739
706
  auto_ddp_sampler: Attach DistributedSampler automatically when distributed; set to False when data is already sharded per rank.
@@ -751,6 +718,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
751
718
  )
752
719
  self.to(self.device)
753
720
 
721
+ if not self.compiled:
722
+ self.compile(
723
+ optimizer="adam",
724
+ optimizer_params={},
725
+ scheduler=None,
726
+ scheduler_params={},
727
+ loss=None,
728
+ loss_params={},
729
+ )
730
+
754
731
  if (
755
732
  self.distributed
756
733
  and dist.is_available()
@@ -825,6 +802,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
825
802
  }
826
803
  training_config: dict = safe_value(training_config) # type: ignore
827
804
 
805
+ if self.is_main_process:
806
+ is_tty = sys.stdin.isatty() and sys.stdout.isatty()
807
+ if not is_tty:
808
+ if use_wandb and wandb_api:
809
+ if wandb is None:
810
+ logging.warning(
811
+ "[BaseModel-fit] wandb not installed, skip wandb login."
812
+ )
813
+ else:
814
+ wandb.login(key=wandb_api)
815
+ if use_swanlab and swanlab_api:
816
+ if swanlab is None:
817
+ logging.warning(
818
+ "[BaseModel-fit] swanlab not installed, skip swanlab login."
819
+ )
820
+ else:
821
+ swanlab.login(api_key=swanlab_api)
822
+
828
823
  self.training_logger = (
829
824
  TrainingLogger(
830
825
  session=self.session,
@@ -1,5 +1,9 @@
1
1
  """
2
2
  Summary utilities for BaseModel.
3
+
4
+ Date: create on 03/12/2025
5
+ Checkpoint: edit on 29/12/2025
6
+ Author: Yang Zhou,zyaztec@gmail.com
3
7
  """
4
8
 
5
9
  from __future__ import annotations
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader
12
16
 
13
17
  from nextrec.basic.loggers import colorize, format_kv
14
18
  from nextrec.data.data_processing import extract_label_arrays, get_data_length
19
+ from nextrec.utils.types import TaskTypeName
15
20
 
16
21
 
17
22
  class SummarySet:
23
+ model_name: str
24
+ dense_features: list[Any]
25
+ sparse_features: list[Any]
26
+ sequence_features: list[Any]
27
+ task: TaskTypeName | list[TaskTypeName]
28
+ target_columns: list[str]
29
+ nums_task: int
30
+ metrics: Any
31
+ device: Any
32
+ optimizer_name: str
33
+ optimizer_params: dict[str, Any]
34
+ scheduler_name: str | None
35
+ scheduler_params: dict[str, Any]
36
+ loss_config: Any
37
+ loss_weights: Any
38
+ grad_norm: Any
39
+ embedding_l1_reg: float
40
+ embedding_l2_reg: float
41
+ dense_l1_reg: float
42
+ dense_l2_reg: float
43
+ early_stop_patience: int
44
+ max_gradient_norm: float | None
45
+ metrics_sample_limit: int | None
46
+ session_id: str | None
47
+ features_config_path: str
48
+ checkpoint_path: str
49
+ train_data_summary: dict[str, Any] | None
50
+ valid_data_summary: dict[str, Any] | None
51
+
18
52
  def build_data_summary(
19
53
  self, data: Any, data_loader: DataLoader | None, sample_key: str
20
54
  ):
@@ -305,7 +339,7 @@ class SummarySet:
305
339
  lines = details.get("lines", [])
306
340
  logger.info(f"{target_name}:")
307
341
  for label, value in lines:
308
- logger.info(format_kv(label, value))
342
+ logger.info(f" {format_kv(label, value)}")
309
343
 
310
344
  if self.valid_data_summary:
311
345
  if self.train_data_summary:
@@ -320,4 +354,4 @@ class SummarySet:
320
354
  lines = details.get("lines", [])
321
355
  logger.info(f"{target_name}:")
322
356
  for label, value in lines:
323
- logger.info(format_kv(label, value))
357
+ logger.info(f" {format_kv(label, value)}")