nextrec 0.4.24.tar.gz → 0.4.25.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. {nextrec-0.4.24 → nextrec-0.4.25}/PKG-INFO +4 -4
  2. {nextrec-0.4.24 → nextrec-0.4.25}/README.md +3 -3
  3. {nextrec-0.4.24 → nextrec-0.4.25}/README_en.md +3 -3
  4. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/conf.py +1 -1
  5. nextrec-0.4.25/nextrec/__version__.py +1 -0
  6. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/model.py +175 -58
  7. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/summary.py +58 -0
  8. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/cli.py +13 -0
  9. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/data/data_processing.py +3 -9
  10. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/data/dataloader.py +25 -2
  11. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/data/preprocessor.py +283 -36
  12. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/config.py +2 -0
  13. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/model.py +14 -70
  14. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/torch_utils.py +11 -0
  15. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/NextRec-CLI.md +17 -0
  16. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/NextRec-CLI_zh.md +17 -0
  17. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/feature_config.yaml +8 -8
  18. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/train_config.yaml +11 -1
  19. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/train_config_template.yaml +12 -0
  20. {nextrec-0.4.24 → nextrec-0.4.25}/pyproject.toml +1 -1
  21. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/movielen_ranking_deepfm.py +1 -1
  22. nextrec-0.4.24/nextrec/__version__.py +0 -1
  23. {nextrec-0.4.24 → nextrec-0.4.25}/.github/workflows/publish.yml +0 -0
  24. {nextrec-0.4.24 → nextrec-0.4.25}/.github/workflows/tests.yml +0 -0
  25. {nextrec-0.4.24 → nextrec-0.4.25}/.gitignore +0 -0
  26. {nextrec-0.4.24 → nextrec-0.4.25}/.readthedocs.yaml +0 -0
  27. {nextrec-0.4.24 → nextrec-0.4.25}/CODE_OF_CONDUCT.md +0 -0
  28. {nextrec-0.4.24 → nextrec-0.4.25}/CONTRIBUTING.md +0 -0
  29. {nextrec-0.4.24 → nextrec-0.4.25}/LICENSE +0 -0
  30. {nextrec-0.4.24 → nextrec-0.4.25}/MANIFEST.in +0 -0
  31. {nextrec-0.4.24 → nextrec-0.4.25}/assets/Feature Configuration.png +0 -0
  32. {nextrec-0.4.24 → nextrec-0.4.25}/assets/Model Parameters.png +0 -0
  33. {nextrec-0.4.24 → nextrec-0.4.25}/assets/Training Configuration.png +0 -0
  34. {nextrec-0.4.24 → nextrec-0.4.25}/assets/Training logs.png +0 -0
  35. {nextrec-0.4.24 → nextrec-0.4.25}/assets/logo.png +0 -0
  36. {nextrec-0.4.24 → nextrec-0.4.25}/assets/mmoe_tutorial.png +0 -0
  37. {nextrec-0.4.24 → nextrec-0.4.25}/assets/nextrec_diagram.png +0 -0
  38. {nextrec-0.4.24 → nextrec-0.4.25}/assets/test data.png +0 -0
  39. {nextrec-0.4.24 → nextrec-0.4.25}/dataset/ctcvr_task.csv +0 -0
  40. {nextrec-0.4.24 → nextrec-0.4.25}/dataset/ecommerce_task.csv +0 -0
  41. {nextrec-0.4.24 → nextrec-0.4.25}/dataset/match_task.csv +0 -0
  42. {nextrec-0.4.24 → nextrec-0.4.25}/dataset/movielens_100k.csv +0 -0
  43. {nextrec-0.4.24 → nextrec-0.4.25}/dataset/multitask_task.csv +0 -0
  44. {nextrec-0.4.24 → nextrec-0.4.25}/dataset/ranking_task.csv +0 -0
  45. {nextrec-0.4.24 → nextrec-0.4.25}/docs/en/Getting started guide.md +0 -0
  46. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/Makefile +0 -0
  47. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/index.md +0 -0
  48. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/make.bat +0 -0
  49. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/modules.rst +0 -0
  50. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/nextrec.basic.rst +0 -0
  51. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/nextrec.data.rst +0 -0
  52. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/nextrec.loss.rst +0 -0
  53. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/nextrec.rst +0 -0
  54. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/nextrec.utils.rst +0 -0
  55. {nextrec-0.4.24 → nextrec-0.4.25}/docs/rtd/requirements.txt +0 -0
  56. {nextrec-0.4.24 → nextrec-0.4.25}/docs/zh/快速上手.md +0 -0
  57. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/__init__.py +0 -0
  58. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/__init__.py +0 -0
  59. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/activation.py +0 -0
  60. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/callback.py +0 -0
  61. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/features.py +0 -0
  62. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/heads.py +0 -0
  63. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/layers.py +0 -0
  64. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/loggers.py +0 -0
  65. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/metrics.py +0 -0
  66. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/basic/session.py +0 -0
  67. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/data/__init__.py +0 -0
  68. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/data/batch_utils.py +0 -0
  69. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/data/data_utils.py +0 -0
  70. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/loss/__init__.py +0 -0
  71. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/loss/grad_norm.py +0 -0
  72. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/loss/listwise.py +0 -0
  73. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/loss/pairwise.py +0 -0
  74. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/loss/pointwise.py +0 -0
  75. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/generative/__init__.py +0 -0
  76. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/generative/tiger.py +0 -0
  77. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/__init__.py +0 -0
  78. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/aitm.py +0 -0
  79. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/apg.py +0 -0
  80. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/cross_stitch.py +0 -0
  81. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/esmm.py +0 -0
  82. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/mmoe.py +0 -0
  83. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/pepnet.py +0 -0
  84. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/ple.py +0 -0
  85. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/poso.py +0 -0
  86. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/share_bottom.py +0 -0
  87. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/multi_task/snr_trans.py +0 -0
  88. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/__init__.py +0 -0
  89. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/afm.py +0 -0
  90. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/autoint.py +0 -0
  91. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/dcn.py +0 -0
  92. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/dcn_v2.py +0 -0
  93. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/deepfm.py +0 -0
  94. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/dien.py +0 -0
  95. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/din.py +0 -0
  96. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/eulernet.py +0 -0
  97. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/ffm.py +0 -0
  98. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/fibinet.py +0 -0
  99. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/fm.py +0 -0
  100. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/lr.py +0 -0
  101. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/masknet.py +0 -0
  102. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/pnn.py +0 -0
  103. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/widedeep.py +0 -0
  104. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/ranking/xdeepfm.py +0 -0
  105. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/__init__.py +0 -0
  106. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/autorec.py +0 -0
  107. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/bpr.py +0 -0
  108. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/cl4srec.py +0 -0
  109. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/lightgcn.py +0 -0
  110. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/mf.py +0 -0
  111. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/rqvae.py +0 -0
  112. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/representation/s3rec.py +0 -0
  113. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/retrieval/__init__.py +0 -0
  114. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/retrieval/dssm.py +0 -0
  115. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/retrieval/dssm_v2.py +0 -0
  116. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/retrieval/mind.py +0 -0
  117. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/retrieval/sdm.py +0 -0
  118. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/retrieval/youtube_dnn.py +0 -0
  119. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/sequential/hstu.py +0 -0
  120. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/models/sequential/sasrec.py +0 -0
  121. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/__init__.py +0 -0
  122. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/console.py +0 -0
  123. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/data.py +0 -0
  124. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/embedding.py +0 -0
  125. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/feature.py +0 -0
  126. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/loss.py +0 -0
  127. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec/utils/types.py +0 -0
  128. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  129. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  130. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  131. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  132. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  133. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  134. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  135. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  136. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  137. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  138. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  139. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  140. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  141. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  142. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  143. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  144. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/predict_config.yaml +0 -0
  145. {nextrec-0.4.24 → nextrec-0.4.25}/nextrec_cli_preset/predict_config_template.yaml +0 -0
  146. {nextrec-0.4.24 → nextrec-0.4.25}/pytest.ini +0 -0
  147. {nextrec-0.4.24 → nextrec-0.4.25}/requirements.txt +0 -0
  148. {nextrec-0.4.24 → nextrec-0.4.25}/scripts/format_code.py +0 -0
  149. {nextrec-0.4.24 → nextrec-0.4.25}/test/__init__.py +0 -0
  150. {nextrec-0.4.24 → nextrec-0.4.25}/test/conftest.py +0 -0
  151. {nextrec-0.4.24 → nextrec-0.4.25}/test/helpers.py +0 -0
  152. {nextrec-0.4.24 → nextrec-0.4.25}/test/run_tests.py +0 -0
  153. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_base_model_regularization.py +0 -0
  154. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_generative_models.py +0 -0
  155. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_layers.py +0 -0
  156. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_losses.py +0 -0
  157. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_match_models.py +0 -0
  158. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_multitask_models.py +0 -0
  159. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_preprocessor.py +0 -0
  160. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_ranking_models.py +0 -0
  161. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_utils_console.py +0 -0
  162. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_utils_data.py +0 -0
  163. {nextrec-0.4.24 → nextrec-0.4.25}/test/test_utils_embedding.py +0 -0
  164. {nextrec-0.4.24 → nextrec-0.4.25}/test_requirements.txt +0 -0
  165. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/distributed/example_distributed_training.py +0 -0
  166. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  167. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/example_match.py +0 -0
  168. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/example_multitask.py +0 -0
  169. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/example_ranking_din.py +0 -0
  170. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/movielen_match_dssm.py +0 -0
  171. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  172. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  173. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  174. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  175. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  176. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0
  177. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/run_all_match_models.py +0 -0
  178. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/run_all_multitask_models.py +0 -0
  179. {nextrec-0.4.24 → nextrec-0.4.25}/tutorials/run_all_ranking_models.py +0 -0
--- nextrec-0.4.24/PKG-INFO
+++ nextrec-0.4.25/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.4.24
+ Version: 0.4.25
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.24-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.25-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  中文文档 | [English Version](README_en.md)
@@ -249,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。

- > 截止当前版本0.4.24,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.4.25,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.24,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.4.25,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|
--- nextrec-0.4.24/README.md
+++ nextrec-0.4.25/README.md
@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.24-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.25-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  中文文档 | [English Version](README_en.md)
@@ -188,11 +188,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。

- > 截止当前版本0.4.24,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.4.25,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.24,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.4.25,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|
--- nextrec-0.4.24/README_en.md
+++ nextrec-0.4.25/README_en.md
@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.24-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.25-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  English | [中文文档](README.md)
@@ -191,11 +191,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.

- > As of version 0.4.24, NextRec CLI supports single-machine training; distributed training features are currently under development.
+ > As of version 0.4.25, NextRec CLI supports single-machine training; distributed training features are currently under development.

  ## Platform Compatibility

- The current version is 0.4.24. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
+ The current version is 0.4.25. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:

  | Platform | Configuration |
  |----------|---------------|
--- nextrec-0.4.24/docs/rtd/conf.py
+++ nextrec-0.4.25/docs/rtd/conf.py
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
  project = "NextRec"
  copyright = "2025, Yang Zhou"
  author = "Yang Zhou"
- release = "0.4.24"
+ release = "0.4.25"

  extensions = [
      "myst_parser",
--- /dev/null
+++ nextrec-0.4.25/nextrec/__version__.py
@@ -0,0 +1 @@
+ __version__ = "0.4.25"
--- nextrec-0.4.24/nextrec/basic/model.py
+++ nextrec-0.4.25/nextrec/basic/model.py
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 30/12/2025
+ Checkpoint: edit on 31/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -88,9 +88,8 @@ from nextrec.utils.config import safe_value
  from nextrec.utils.model import (
      compute_ranking_loss,
      get_loss_list,
-     resolve_loss_weights,
-     get_training_modes,
  )
+
  from nextrec.utils.types import (
      LossName,
      OptimizerName,
@@ -100,6 +99,7 @@ from nextrec.utils.types import (
      MetricsName,
  )

+ from nextrec.utils.data import FILE_FORMAT_CONFIG

  class BaseModel(SummarySet, FeatureSet, nn.Module):
      @property
@@ -110,6 +110,30 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
      def default_task(self) -> TaskTypeName | list[TaskTypeName]:
          raise NotImplementedError

+     @property
+     def training_mode(self) -> TrainingModeName | list[TrainingModeName]:
+         if self.nums_task > 1:
+             return self.training_modes
+         return self.training_modes[0] if self.training_modes else "pointwise"
+
+
+     @training_mode.setter
+     def training_mode(self, training_mode: TrainingModeName | list[TrainingModeName]):
+         valid_modes = {"pointwise", "pairwise", "listwise"}
+         if isinstance(training_mode, list):
+             training_modes = list(training_mode)
+             if len(training_modes) != self.nums_task:
+                 raise ValueError(
+                     "[BaseModel-init Error] training_mode list length must match number of tasks."
+                 )
+         else:
+             training_modes = [training_mode] * self.nums_task
+         if any(mode not in valid_modes for mode in training_modes):
+             raise ValueError(
+                 "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+             )
+         self.training_modes = list(training_modes)
+
      def __init__(
          self,
          dense_features: list[DenseFeature] | None = None,
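The property/setter pair added above replaces the old get_training_modes helper: every assignment to training_mode is now normalized to a per-task list and validated. A minimal sketch of the resulting behavior, assuming model is any constructed single-task BaseModel subclass:

    model.training_mode = "pairwise"   # stored internally as ["pairwise"]
    print(model.training_mode)         # -> "pairwise" (single task unwraps the list)
    model.training_mode = "bpr"        # raises ValueError: not in {"pointwise", "pairwise", "listwise"}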
@@ -193,10 +217,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

          self.task = task or self.default_task
          self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-         self.training_modes = get_training_modes(training_mode, self.nums_task)
-         self.training_mode = (
-             self.training_modes if self.nums_task > 1 else self.training_modes[0]
-         )
+
+         self.training_mode = training_mode

          self.embedding_l1_reg = embedding_l1_reg
          self.dense_l1_reg = dense_l1_reg
@@ -215,6 +237,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):

          self.train_data_summary = None
          self.valid_data_summary = None
+         self.note = None

      def register_regularization_weights(
          self,
@@ -222,6 +245,15 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          exclude_modules: list[str] | None = None,
          include_modules: list[str] | None = None,
      ):
+         """
+         Register parameters for regularization.
+         By default, all nn.Linear weights (excluding those in BatchNorm/Dropout layers) and embedding weights under `embedding_attr` are registered.
+
+         Args:
+             embedding_attr: Attribute name of the embedding layer/module.
+             exclude_modules: List of module name substrings to exclude from regularization.
+             include_modules: List of module name substrings to include for regularization. If provided, only modules containing these substrings are included.
+         """
          exclude_modules = exclude_modules or []
          include_modules = include_modules or []
          embedding_layer = getattr(self, embedding_attr, None)
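With the docstring in place, the registration contract is explicit. A sketch of how a subclass might call it at the end of its own __init__ (the attribute name "embedding" and the module substrings are illustrative, not taken from this diff):

    # hypothetical subclass __init__ tail; names are illustrative
    self.register_regularization_weights(
        embedding_attr="embedding",            # assumed name of the embedding module
        exclude_modules=["prediction_layer"],  # skip modules whose name contains this substring
    )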
@@ -268,6 +300,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  existing_reg_ids.add(id(module.weight))

      def add_reg_loss(self) -> torch.Tensor:
+         """
+         Compute the regularization loss based on registered parameters and their respective regularization strengths.
+         """
          reg_loss = torch.tensor(0.0, device=self.device)

          if self.embedding_l1_reg > 0:
@@ -289,9 +324,25 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              )
          return reg_loss

+     # todo: support build pairwise/listwise label in input
      def get_input(self, input_data: dict, require_labels: bool = True):
+         """
+         Prepare unified input features and labels from the given input data.
+
+
+         Args:
+             input_data: Input data dictionary containing 'features' and optionally 'labels', e.g., {'features': {'feat1': [...], 'feat2': [...]}, 'labels': {'label': [...]}}.
+             require_labels: Whether labels are required in the input data. Default is True: for training and evaluation with labels.
+
+         Note:
+             target tensor shape will always be (batch_size, num_targets)
+         """
          feature_source = input_data.get("features", {})
+         # todo: pairwise/listwise label support
+         # "labels": {...} should contain pointwise/pair index/list index/ relevance scores
+         # now only have pointwise label support
          label_source = input_data.get("labels")
+
          X_input = {}
          for feature in self.all_features:
              if feature.name not in feature_source:
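The added docstring and todos pin down the current contract: labels are pointwise only for now, and targets always come back shaped (batch_size, num_targets). A minimal sketch of the expected payload (feature and label names are illustrative):

    batch = {
        "features": {"user_id": [1, 2, 3], "item_id": [10, 20, 30]},  # illustrative names
        "labels": {"label": [1.0, 0.0, 1.0]},                         # pointwise labels only
    }
    X_input, y = model.get_input(batch, require_labels=True)
    # y.shape == (3, 1) for a single target column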
@@ -307,13 +358,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  device=self.device,
              )
          y = None
+         # if need labels: training or eval with labels
          if len(self.target_columns) > 0 and (
              require_labels
              or (
                  label_source
                  and any(name in label_source for name in self.target_columns)
              )
-         ): # need labels: training or eval with labels
+         ):
              target_tensors = []
              for target_name in self.target_columns:
                  if label_source is None or target_name not in label_source:
@@ -358,6 +410,10 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          This function will split training data into training and validation sets when:
          1. valid_data is None;
          2. valid_split is provided.
+
+         Returns:
+             train_loader: DataLoader for training data.
+             valid_split_data: Validation data dict/dataframe split from training data.
          """
          if not (0 < valid_split < 1):
              raise ValueError(
@@ -375,7 +431,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              )
          else:
              raise TypeError(
-                 f"[BaseModel-validation Error] If you want to use valid_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
+                 f"[BaseModel-validation Error] If you want to use valid_split, train_data must be DataFrame or a dict, now got {type(train_data)}"
              )
          rng = np.random.default_rng(42)
          indices = rng.permutation(total_length)
@@ -426,7 +482,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          Args:
              optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
              optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
-             scheduler: Learning rate scheduler name or instance. e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
+             scheduler: Learning rate scheduler name or instance. e.g., 'step', 'cosine', or torch.optim.lr_scheduler.StepLR().
              scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
              loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
              loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
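The docstring now uses the short scheduler aliases ('step', 'cosine') rather than the old 'step_lr'/'cosine_annealing'. Pulling the documented example values together, a compile call might look like this sketch:

    model.compile(
        optimizer="adam",
        optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
        scheduler="step",                                  # was 'step_lr' in the old docstring
        scheduler_params={"step_size": 10, "gamma": 0.1},
        loss="bce",
    )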
@@ -435,36 +491,31 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
          """
          self.ignore_label = ignore_label
-         default_losses = {
-             "pointwise": "bce",
-             "pairwise": "bpr",
-             "listwise": "listnet",
-         }
          loss_list = get_loss_list(
-             loss, self.training_modes, self.nums_task, default_losses
+             loss, self.training_modes, self.nums_task
          )
-         self.loss_params = loss_params or {}
-         optimizer_params = optimizer_params or {}
+
+         self.loss_params = {} if loss_params is None else loss_params
+         self.optimizer_params = optimizer_params or {}
+         self.scheduler_params = scheduler_params or {}
+
          self.optimizer_name = (
              optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
          )
-         self.optimizer_params = optimizer_params
          self.optimizer_fn = get_optimizer(
              optimizer=optimizer,
              params=self.parameters(),
-             **optimizer_params,
+             **self.optimizer_params,
          )

-         scheduler_params = scheduler_params or {}
          if scheduler is None:
              self.scheduler_name = None
          elif isinstance(scheduler, str):
              self.scheduler_name = scheduler
          else:
              self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
-         self.scheduler_params = scheduler_params
          self.scheduler_fn = (
-             get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
+             get_scheduler(scheduler, self.optimizer_fn, **self.scheduler_params)
              if scheduler
              else None
          )
@@ -482,35 +533,54 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              for i in range(self.nums_task)
          ]

+         # loss weighting (grad norm or fixed weights)
          self.grad_norm = None
          self.grad_norm_shared_params = None
-         if isinstance(loss_weights, str) and loss_weights.lower() == "grad_norm":
-             if self.nums_task == 1:
-                 raise ValueError(
-                     "[BaseModel-compile Error] GradNorm requires multi-task setup."
-                 )
-             self.grad_norm = GradNormLossWeighting(
-                 nums_task=self.nums_task, device=self.device
-             )
-             self.loss_weights = None
-         elif (
-             isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
-         ):
+         is_grad_norm = (
+             loss_weights == "grad_norm"
+             or isinstance(loss_weights, dict)
+             and loss_weights.get("method") == "grad_norm"
+         )
+         if is_grad_norm:
              if self.nums_task == 1:
                  raise ValueError(
                      "[BaseModel-compile Error] GradNorm requires multi-task setup."
                  )
-             grad_norm_params = dict(loss_weights)
+             grad_norm_params = dict(loss_weights) if isinstance(loss_weights, dict) else {}
              grad_norm_params.pop("method", None)
              self.grad_norm = GradNormLossWeighting(
                  nums_task=self.nums_task, device=self.device, **grad_norm_params
              )
              self.loss_weights = None
+         elif loss_weights is None:
+             self.loss_weights = None
+         elif self.nums_task == 1:
+             if isinstance(loss_weights, (list, tuple)):
+                 if len(loss_weights) != 1:
+                     raise ValueError(
+                         "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
+                     )
+                 loss_weights = loss_weights[0]
+             self.loss_weights = [float(loss_weights)]
+         elif isinstance(loss_weights, (int, float)):
+             self.loss_weights = [float(loss_weights)] * self.nums_task
+         elif isinstance(loss_weights, (list, tuple)):
+             weights = [float(w) for w in loss_weights]
+             if len(weights) != self.nums_task:
+                 raise ValueError(
+                     f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
+                 )
+             self.loss_weights = weights
          else:
-             self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
+             raise TypeError(
+                 f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
+             )
          self.compiled = True

      def compute_loss(self, y_pred, y_true):
+         """
+         Compute the loss between predictions and ground truth labels, with loss weighting and ignore_label handling
+         """
          if y_true is None:
              raise ValueError(
                  "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
@@ -522,13 +592,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              y_pred = y_pred.view(-1, 1)
          if y_true.dim() == 1:
              y_true = y_true.view(-1, 1)
-         if y_pred.shape != y_true.shape:
-             raise ValueError(
-                 f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
-             )

          loss_fn = self.loss_fn[0]
-
+
+         # mask ignored labels
+         # we don't suggest using ignore_label for single task training
          if self.ignore_label is not None:
              valid_mask = y_true != self.ignore_label
              if valid_mask.dim() > 1:
@@ -559,9 +627,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              loss *= self.loss_weights[0]
          return loss

-         # multi-task
-         if y_pred.shape != y_true.shape:
-             raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+         # multi-task: slice predictions and labels per task
          slices = (
              self.prediction_layer.task_slices # type: ignore
              if hasattr(self, "prediction_layer")
@@ -593,9 +659,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  )
              else:
                  task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-             # task_loss = normalize_task_loss(
-             #     task_loss, valid_count, total_count
-             # ) # normalize by valid samples to avoid loss scale issues
+                 # task_loss = normalize_task_loss(
+                 #     task_loss, valid_count, total_count
+                 # ) # normalize by valid samples to avoid loss scale issues
              task_losses.append(task_loss)

          if self.grad_norm is not None:
@@ -624,6 +690,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
      ):
          """
          Prepare a DataLoader from input data. Only used when input data is not a DataLoader.
+
+         Args:
+             data: Input data (dict/df/DataLoader).
+             batch_size: Batch size.
+             shuffle: Whether to shuffle the data (ignored when a sampler is provided).
+             num_workers: Number of DataLoader workers.
+             sampler: Optional sampler for DataLoader.
+             return_dataset: Whether to return the tensor dataset along with the DataLoader, used for valid data
+         Returns:
+             DataLoader (and tensor dataset if return_dataset is True).
          """
          if isinstance(data, DataLoader):
              return (data, None) if return_dataset else data
@@ -676,6 +752,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          swanlab_kwargs: dict | None = None,
          auto_ddp_sampler: bool = True,
          log_interval: int = 1,
+         note: str | None = None,
          summary_sections: (
              list[Literal["feature", "model", "train", "data"]] | None
          ) = None,
@@ -707,6 +784,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              swanlab_kwargs: Optional kwargs for swanlab.init(...).
              auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
              log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+             note: Optional note for the training run.
              summary_sections: Optional summary sections to print. Choose from
                  ["feature", "model", "train", "data"]. Defaults to all.

@@ -770,11 +848,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          self.metrics_sample_limit = (
              None if metrics_sample_limit is None else int(metrics_sample_limit)
          )
+         self.note = note

          training_config = {}
          if self.is_main_process:
              training_config = {
                  "model_name": getattr(self, "model_name", self.__class__.__name__),
+                 "note": self.note,
                  "task": self.task,
                  "target_columns": self.target_columns,
                  "batch_size": batch_size,
@@ -1253,7 +1333,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          for batch_index, batch_data in batch_iter:
              batch_dict = batch_to_dict(batch_data)
              X_input, y_true = self.get_input(batch_dict, require_labels=True)
-             # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
+             # call via __call__ so DDP hooks run
              y_pred = model(X_input) # type: ignore

              loss = self.compute_loss(y_pred, y_true)
@@ -1556,7 +1636,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          num_workers: int = 0,
      ) -> pd.DataFrame | np.ndarray | Path | None:
          """
-         Note: predict does not support distributed mode currently, consider it as a single-process operation.
          Make predictions on the given data.

          Args:
@@ -1569,6 +1648,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
              return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
              stream_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
              num_workers: DataLoader worker count.
+
+         Note:
+             predict does not support distributed mode currently, consider it as a single-process operation.
          """
          self.eval()
          # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
@@ -1753,6 +1835,21 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          return_dataframe: bool,
          id_columns: list[str] | None = None,
      ):
+         """
+         Make predictions on the given data using streaming mode for large datasets.
+
+         Args:
+             data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+             batch_size: Batch size for prediction.
+             save_path: Path to save predictions.
+             save_format: Format to save predictions ('csv' or 'parquet').
+             include_ids: Whether to include ID columns in the output.
+             stream_chunk_size: Number of rows per chunk when using streaming mode.
+             return_dataframe: Whether to return predictions as a pandas DataFrame.
+             id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+         Note:
+             This method uses streaming writes to handle large datasets without loading all data into memory.
+         """
          if isinstance(data, (str, os.PathLike)):
              rec_loader = RecDataLoader(
                  dense_features=self.dense_features,
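On the public predict side this streaming path is driven by stream_chunk_size, and a file-path input goes through RecDataLoader as shown above. A sketch using a dataset file that ships with the package (the data keyword and the sizes are assumptions):

    preds = model.predict(
        data="dataset/ranking_task.csv",  # file path triggers the streaming loader
        batch_size=1024,                  # illustrative
        stream_chunk_size=100_000,        # rows per streamed chunk
        return_dataframe=True,
    )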
@@ -1795,8 +1892,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
                  "Results will be collected in memory and saved at the end. Use csv or parquet for true streaming."
              )

-         from nextrec.utils.data import FILE_FORMAT_CONFIG
-
          suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

          target_path = get_save_path(
@@ -1908,6 +2003,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          add_timestamp: bool | None = None,
          verbose: bool = True,
      ):
+         """
+         Save the model state and features configuration to disk.
+
+         Args:
+             save_path: Path to save the model; if None, saves to the session's model directory.
+             add_timestamp: Whether to add a timestamp to the filename; if None, defaults to True.
+             verbose: Whether to log the save location.
+         """
          add_timestamp = False if add_timestamp is None else add_timestamp
          target_path = get_save_path(
              path=save_path,
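Assuming this method is the model's save entry point (its name sits outside the hunk), the new docstring implies usage along these lines:

    model.save(save_path="checkpoints/deepfm_run", add_timestamp=False)  # method name assumed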
@@ -1950,6 +2053,14 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          map_location: str | torch.device | None = "cpu",
          verbose: bool = True,
      ):
+         """
+         Load the model state and features configuration from disk.
+
+         Args:
+             save_path: Path to load the model from; can be a directory or a specific .pt file.
+             map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+             verbose: Whether to log the load location.
+         """
          self.to(self.device)
          base_path = Path(save_path)
          if base_path.is_dir():
@@ -2016,6 +2127,13 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
          """
          Load a model from a checkpoint path. The checkpoint path should contain:
          a .pt file and a features_config.pkl file.
+
+         Args:
+             checkpoint_path: Path to the checkpoint directory or specific .pt file.
+             map_location: Device mapping for loading the model (e.g., 'cpu', 'cuda:0').
+             device: Device to place the model on after loading.
+             session_id: Optional session ID for the model.
+             **kwargs: Additional keyword arguments to pass to the model constructor.
+         """
          base_path = Path(checkpoint_path)
          verbose = kwargs.pop("verbose", True)
@@ -2135,6 +2253,7 @@ class BaseMatchModel(BaseModel):
              target=target,
              id_columns=id_columns,
              task=task,
+             training_mode=training_mode,
              device=device,
              embedding_l1_reg=embedding_l1_reg,
              dense_l1_reg=dense_l1_reg,
@@ -2157,10 +2276,13 @@ class BaseMatchModel(BaseModel):
          self.item_sparse_features = item_sparse_features
          self.item_sequence_features = item_sequence_features

-         self.training_mode = training_mode
          self.num_negative_samples = num_negative_samples
          self.temperature = temperature
          self.similarity_metric = similarity_metric
+         if self.training_mode not in self.support_training_modes:
+             raise ValueError(
+                 f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+             )
          self.user_features_all = (
              self.user_dense_features
              + self.user_sparse_features
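Moving the support_training_modes check from compile into __init__ means an unsupported mode now fails at construction time rather than at compile. A hypothetical sketch with DSSM, one of the package's retrieval models (constructor arguments elided):

    model = DSSM(..., training_mode="listwise")
    # raises ValueError immediately if DSSM's support_training_modes excludes "listwise"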
@@ -2209,11 +2331,6 @@ class BaseMatchModel(BaseModel):
              loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
              loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
          """
-         if self.training_mode not in self.support_training_modes:
-             raise ValueError(
-                 f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
-             )
-
          default_loss_by_mode = {
              "pointwise": "bce",
              "pairwise": "bpr",