nextrec 0.4.21__tar.gz → 0.4.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175) hide show
  1. {nextrec-0.4.21 → nextrec-0.4.23}/PKG-INFO +8 -6
  2. {nextrec-0.4.21 → nextrec-0.4.23}/README.md +7 -5
  3. {nextrec-0.4.21 → nextrec-0.4.23}/README_en.md +5 -5
  4. {nextrec-0.4.21 → nextrec-0.4.23}/docs/en/Getting started guide.md +1 -0
  5. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/conf.py +1 -1
  6. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/index.md +1 -0
  7. {nextrec-0.4.21 → nextrec-0.4.23}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md +1 -0
  8. nextrec-0.4.23/nextrec/__version__.py +1 -0
  9. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/activation.py +1 -1
  10. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/heads.py +2 -3
  11. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/metrics.py +1 -2
  12. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/model.py +115 -80
  13. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/summary.py +36 -2
  14. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/preprocessor.py +137 -5
  15. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/__init__.py +0 -4
  16. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/grad_norm.py +3 -3
  17. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/listwise.py +19 -6
  18. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/pairwise.py +6 -4
  19. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/pointwise.py +8 -6
  20. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/esmm.py +3 -26
  21. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/mmoe.py +2 -24
  22. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/ple.py +13 -35
  23. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/poso.py +4 -28
  24. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/share_bottom.py +1 -24
  25. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/afm.py +3 -27
  26. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/autoint.py +5 -38
  27. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/dcn.py +1 -26
  28. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/dcn_v2.py +5 -33
  29. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/deepfm.py +2 -29
  30. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/dien.py +2 -28
  31. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/din.py +2 -27
  32. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/eulernet.py +3 -30
  33. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/ffm.py +0 -26
  34. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/fibinet.py +8 -32
  35. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/fm.py +0 -29
  36. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/lr.py +0 -30
  37. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/masknet.py +4 -30
  38. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/pnn.py +4 -28
  39. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/widedeep.py +0 -32
  40. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/xdeepfm.py +0 -30
  41. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/dssm.py +0 -24
  42. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/dssm_v2.py +0 -24
  43. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/mind.py +0 -20
  44. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/sdm.py +0 -20
  45. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/youtube_dnn.py +0 -21
  46. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/sequential/hstu.py +0 -18
  47. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/__init__.py +5 -1
  48. nextrec-0.4.21/nextrec/loss/loss_utils.py → nextrec-0.4.23/nextrec/utils/loss.py +17 -7
  49. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/model.py +79 -1
  50. nextrec-0.4.23/nextrec/utils/types.py +98 -0
  51. {nextrec-0.4.21 → nextrec-0.4.23}/pyproject.toml +1 -1
  52. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_losses.py +54 -1
  53. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_ranking_models.py +2 -3
  54. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/example_multitask.py +1 -8
  55. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/example_ranking_din.py +3 -5
  56. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on nextrec.ipynb +1 -1
  57. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/zh//345/277/253/351/200/237/345/205/245/351/227/250nextrec.ipynb +1 -1
  58. nextrec-0.4.21/nextrec/__version__.py +0 -1
  59. nextrec-0.4.21/nextrec/utils/types.py +0 -59
  60. {nextrec-0.4.21 → nextrec-0.4.23}/.github/workflows/publish.yml +0 -0
  61. {nextrec-0.4.21 → nextrec-0.4.23}/.github/workflows/tests.yml +0 -0
  62. {nextrec-0.4.21 → nextrec-0.4.23}/.gitignore +0 -0
  63. {nextrec-0.4.21 → nextrec-0.4.23}/.readthedocs.yaml +0 -0
  64. {nextrec-0.4.21 → nextrec-0.4.23}/CODE_OF_CONDUCT.md +0 -0
  65. {nextrec-0.4.21 → nextrec-0.4.23}/CONTRIBUTING.md +0 -0
  66. {nextrec-0.4.21 → nextrec-0.4.23}/LICENSE +0 -0
  67. {nextrec-0.4.21 → nextrec-0.4.23}/MANIFEST.in +0 -0
  68. {nextrec-0.4.21 → nextrec-0.4.23}/assets/Feature Configuration.png +0 -0
  69. {nextrec-0.4.21 → nextrec-0.4.23}/assets/Model Parameters.png +0 -0
  70. {nextrec-0.4.21 → nextrec-0.4.23}/assets/Training Configuration.png +0 -0
  71. {nextrec-0.4.21 → nextrec-0.4.23}/assets/Training logs.png +0 -0
  72. {nextrec-0.4.21 → nextrec-0.4.23}/assets/logo.png +0 -0
  73. {nextrec-0.4.21 → nextrec-0.4.23}/assets/mmoe_tutorial.png +0 -0
  74. {nextrec-0.4.21 → nextrec-0.4.23}/assets/nextrec_diagram.png +0 -0
  75. {nextrec-0.4.21 → nextrec-0.4.23}/assets/test data.png +0 -0
  76. {nextrec-0.4.21 → nextrec-0.4.23}/dataset/ctcvr_task.csv +0 -0
  77. {nextrec-0.4.21 → nextrec-0.4.23}/dataset/ecommerce_task.csv +0 -0
  78. {nextrec-0.4.21 → nextrec-0.4.23}/dataset/match_task.csv +0 -0
  79. {nextrec-0.4.21 → nextrec-0.4.23}/dataset/movielens_100k.csv +0 -0
  80. {nextrec-0.4.21 → nextrec-0.4.23}/dataset/multitask_task.csv +0 -0
  81. {nextrec-0.4.21 → nextrec-0.4.23}/dataset/ranking_task.csv +0 -0
  82. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/Makefile +0 -0
  83. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/make.bat +0 -0
  84. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/modules.rst +0 -0
  85. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.basic.rst +0 -0
  86. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.data.rst +0 -0
  87. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.loss.rst +0 -0
  88. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.rst +0 -0
  89. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.utils.rst +0 -0
  90. {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/requirements.txt +0 -0
  91. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/__init__.py +0 -0
  92. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/__init__.py +0 -0
  93. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/callback.py +0 -0
  94. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/features.py +0 -0
  95. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/layers.py +0 -0
  96. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/loggers.py +0 -0
  97. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/session.py +0 -0
  98. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/cli.py +0 -0
  99. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/__init__.py +0 -0
  100. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/batch_utils.py +0 -0
  101. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/data_processing.py +0 -0
  102. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/data_utils.py +0 -0
  103. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/dataloader.py +0 -0
  104. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/generative/__init__.py +0 -0
  105. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/generative/tiger.py +0 -0
  106. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/__init__.py +0 -0
  107. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/__init__.py +0 -0
  108. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/__init__.py +0 -0
  109. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/autorec.py +0 -0
  110. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/bpr.py +0 -0
  111. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/cl4srec.py +0 -0
  112. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/lightgcn.py +0 -0
  113. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/mf.py +0 -0
  114. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/rqvae.py +0 -0
  115. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/s3rec.py +0 -0
  116. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/__init__.py +0 -0
  117. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/sequential/sasrec.py +0 -0
  118. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/config.py +0 -0
  119. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/console.py +0 -0
  120. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/data.py +0 -0
  121. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/embedding.py +0 -0
  122. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/feature.py +0 -0
  123. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/torch_utils.py +0 -0
  124. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI.md +0 -0
  125. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
  126. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/feature_config.yaml +0 -0
  127. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  128. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  129. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  130. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  131. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  132. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  133. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  134. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  135. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  136. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  137. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  138. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  139. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  140. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  141. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  142. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  143. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/predict_config.yaml +0 -0
  144. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/predict_config_template.yaml +0 -0
  145. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/train_config.yaml +0 -0
  146. {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/train_config_template.yaml +0 -0
  147. {nextrec-0.4.21 → nextrec-0.4.23}/pytest.ini +0 -0
  148. {nextrec-0.4.21 → nextrec-0.4.23}/requirements.txt +0 -0
  149. {nextrec-0.4.21 → nextrec-0.4.23}/scripts/format_code.py +0 -0
  150. {nextrec-0.4.21 → nextrec-0.4.23}/test/__init__.py +0 -0
  151. {nextrec-0.4.21 → nextrec-0.4.23}/test/conftest.py +0 -0
  152. {nextrec-0.4.21 → nextrec-0.4.23}/test/helpers.py +0 -0
  153. {nextrec-0.4.21 → nextrec-0.4.23}/test/run_tests.py +0 -0
  154. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_base_model_regularization.py +0 -0
  155. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_generative_models.py +0 -0
  156. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_layers.py +0 -0
  157. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_match_models.py +0 -0
  158. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_multitask_models.py +0 -0
  159. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_preprocessor.py +0 -0
  160. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_utils_console.py +0 -0
  161. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_utils_data.py +0 -0
  162. {nextrec-0.4.21 → nextrec-0.4.23}/test/test_utils_embedding.py +0 -0
  163. {nextrec-0.4.21 → nextrec-0.4.23}/test_requirements.txt +0 -0
  164. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training.py +0 -0
  165. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  166. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/example_match.py +0 -0
  167. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/movielen_match_dssm.py +0 -0
  168. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/movielen_ranking_deepfm.py +0 -0
  169. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  170. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  171. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/zh//344/275/277/347/224/250RQ-VAE/346/236/204/345/273/272/350/257/255/344/271/211ID.ipynb +0 -0
  172. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/zh//345/246/202/344/275/225/344/275/277/347/224/250DataProcessor/350/277/233/350/241/214/351/242/204/345/244/204/347/220/206.ipynb +0 -0
  173. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/run_all_match_models.py +0 -0
  174. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/run_all_multitask_models.py +0 -0
  175. {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/run_all_ranking_models.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.4.21
3
+ Version: 0.4.23
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -65,11 +65,11 @@ Description-Content-Type: text/markdown
65
65
 
66
66
  <div align="center">
67
67
 
68
- [![PyPI Downloads](https://static.pepy.tech/personalized-badge/nextrec?period=total&units=NONE&left_color=BLACK&right_color=GREEN&left_text=PyPI-downloads)](https://pypistats.org/packages/nextrec)
68
+ [![PyPI Downloads](https://static.pepy.tech/personalized-badge/nextrec?period=total&units=NONE&left_color=grey&right_color=GREEN&left_text=PyPI-downloads)](https://pypistats.org/packages/nextrec)
69
69
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
70
70
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
71
71
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
72
- ![Version](https://img.shields.io/badge/Version-0.4.21-orange.svg)
72
+ ![Version](https://img.shields.io/badge/Version-0.4.23-orange.svg)
73
73
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
74
74
 
75
75
  中文文档 | [English Version](README_en.md)
@@ -191,6 +191,8 @@ model = DIN(
191
191
  dense_features=dense_features,
192
192
  sparse_features=sparse_features,
193
193
  sequence_features=sequence_features,
194
+ behavior_feature_name="sequence_0",
195
+ candidate_feature_name="item_id",
194
196
  mlp_params=mlp_params,
195
197
  attention_hidden_units=[80, 40],
196
198
  attention_activation='sigmoid',
@@ -204,7 +206,7 @@ model = DIN(
204
206
  session_id="din_tutorial", # 实验id,用于存放训练日志
205
207
  )
206
208
 
207
- # 编译模型,设置优化器和损失函数
209
+ # 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
208
210
  model.compile(
209
211
  optimizer = "adam",
210
212
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -247,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
247
249
 
248
250
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
249
251
 
250
- > 截止当前版本0.4.21,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
252
+ > 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
251
253
 
252
254
  ## 兼容平台
253
255
 
254
- 当前最新版本为0.4.21,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
256
+ 当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
255
257
 
256
258
  | 平台 | 配置 |
257
259
  |------|------|
@@ -4,11 +4,11 @@
4
4
 
5
5
  <div align="center">
6
6
 
7
- [![PyPI Downloads](https://static.pepy.tech/personalized-badge/nextrec?period=total&units=NONE&left_color=BLACK&right_color=GREEN&left_text=PyPI-downloads)](https://pypistats.org/packages/nextrec)
7
+ [![PyPI Downloads](https://static.pepy.tech/personalized-badge/nextrec?period=total&units=NONE&left_color=grey&right_color=GREEN&left_text=PyPI-downloads)](https://pypistats.org/packages/nextrec)
8
8
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
11
- ![Version](https://img.shields.io/badge/Version-0.4.21-orange.svg)
11
+ ![Version](https://img.shields.io/badge/Version-0.4.23-orange.svg)
12
12
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
13
13
 
14
14
  中文文档 | [English Version](README_en.md)
@@ -130,6 +130,8 @@ model = DIN(
130
130
  dense_features=dense_features,
131
131
  sparse_features=sparse_features,
132
132
  sequence_features=sequence_features,
133
+ behavior_feature_name="sequence_0",
134
+ candidate_feature_name="item_id",
133
135
  mlp_params=mlp_params,
134
136
  attention_hidden_units=[80, 40],
135
137
  attention_activation='sigmoid',
@@ -143,7 +145,7 @@ model = DIN(
143
145
  session_id="din_tutorial", # 实验id,用于存放训练日志
144
146
  )
145
147
 
146
- # 编译模型,设置优化器和损失函数
148
+ # 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
147
149
  model.compile(
148
150
  optimizer = "adam",
149
151
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -186,11 +188,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
186
188
 
187
189
  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
188
190
 
189
- > 截止当前版本0.4.21,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
191
+ > 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
190
192
 
191
193
  ## 兼容平台
192
194
 
193
- 当前最新版本为0.4.21,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
195
+ 当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
194
196
 
195
197
  | 平台 | 配置 |
196
198
  |------|------|
@@ -4,11 +4,11 @@
4
4
 
5
5
  <div align="center">
6
6
 
7
- [![PyPI Downloads](https://static.pepy.tech/personalized-badge/nextrec?period=total&units=NONE&left_color=BLACK&right_color=GREEN&left_text=PyPI-downloads)](https://pypistats.org/packages/nextrec)
7
+ [![PyPI Downloads](https://static.pepy.tech/personalized-badge/nextrec?period=total&units=NONE&left_color=grey&right_color=GREEN&left_text=PyPI-downloads)](https://pypistats.org/packages/nextrec)
8
8
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
9
9
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
10
10
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
11
- ![Version](https://img.shields.io/badge/Version-0.4.21-orange.svg)
11
+ ![Version](https://img.shields.io/badge/Version-0.4.23-orange.svg)
12
12
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)
13
13
 
14
14
  English | [中文文档](README.md)
@@ -148,7 +148,7 @@ model = DIN(
148
148
  session_id="din_tutorial", # experiment id for logs
149
149
  )
150
150
 
151
- # Compile model with optimizer and loss
151
+ # Compile model; configure optimizer/loss/scheduler via compile()
152
152
  model.compile(
153
153
  optimizer = "adam",
154
154
  optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
@@ -191,11 +191,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
191
191
 
192
192
  Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.
193
193
 
194
- > As of version 0.4.21, NextRec CLI supports single-machine training; distributed training features are currently under development.
194
+ > As of version 0.4.23, NextRec CLI supports single-machine training; distributed training features are currently under development.
195
195
 
196
196
  ## Platform Compatibility
197
197
 
198
- The current version is 0.4.21. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
198
+ The current version is 0.4.23. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
199
199
 
200
200
  | Platform | Configuration |
201
201
  |----------|---------------|
@@ -55,6 +55,7 @@ model = DeepFM(
55
55
  session_id="movielens_deepfm", # manages logs and checkpoints
56
56
  )
57
57
 
58
+ # Optimizer/loss/scheduler are configured via compile()
58
59
  model.compile(
59
60
  optimizer="adam",
60
61
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
11
11
  project = "NextRec"
12
12
  copyright = "2025, Yang Zhou"
13
13
  author = "Yang Zhou"
14
- release = "0.4.21"
14
+ release = "0.4.23"
15
15
 
16
16
  extensions = [
17
17
  "myst_parser",
@@ -57,6 +57,7 @@ model = DeepFM(
57
57
  session_id="deepfm_demo",
58
58
  )
59
59
 
60
+ # Configure optimizer/loss/scheduler via compile()
60
61
  model.compile(
61
62
  optimizer="adam",
62
63
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -55,6 +55,7 @@ model = DeepFM(
55
55
  session_id="movielens_deepfm", # 管理实验日志与检查点
56
56
  )
57
57
 
58
+ # 优化器/损失/学习率调度器统一在 compile 中设置
58
59
  model.compile(
59
60
  optimizer="adam",
60
61
  optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
@@ -0,0 +1 @@
1
+ __version__ = "0.4.23"
@@ -9,10 +9,10 @@ Author: Yang Zhou, zyaztec@gmail.com
9
9
  import torch
10
10
  import torch.nn as nn
11
11
 
12
- from typing import Literal
13
12
 
14
13
  from nextrec.utils.types import ActivationName
15
14
 
15
+
16
16
  class Dice(nn.Module):
17
17
  """
18
18
  Dice activation function from the paper:
@@ -15,6 +15,7 @@ import torch.nn as nn
15
15
  import torch.nn.functional as F
16
16
 
17
17
  from nextrec.basic.layers import PredictionLayer
18
+ from nextrec.utils.types import TaskTypeName
18
19
 
19
20
 
20
21
  class TaskHead(nn.Module):
@@ -27,9 +28,7 @@ class TaskHead(nn.Module):
27
28
 
28
29
  def __init__(
29
30
  self,
30
- task_type: (
31
- Literal["binary", "regression"] | list[Literal["binary", "regression"]]
32
- ) = "binary",
31
+ task_type: TaskTypeName | list[TaskTypeName] = "binary",
33
32
  task_dims: int | list[int] | None = None,
34
33
  use_bias: bool = True,
35
34
  return_logits: bool = False,
@@ -2,7 +2,7 @@
2
2
  Metrics computation and configuration for model evaluation.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 20/12/2025
5
+ Checkpoint: edit on 29/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -39,7 +39,6 @@ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
39
39
  TASK_DEFAULT_METRICS = {
40
40
  "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
41
41
  "regression": ["mse", "mae", "rmse", "r2", "mape"],
42
- "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
43
42
  "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
44
43
  + [f"recall@{k}" for k in (5, 10, 20)]
45
44
  + [f"ndcg@{k}" for k in (5, 10, 20)]
@@ -2,13 +2,14 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 28/12/2025
5
+ Checkpoint: edit on 29/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
9
9
  import getpass
10
10
  import logging
11
11
  import os
12
+ import sys
12
13
  import pickle
13
14
  import socket
14
15
  from pathlib import Path
@@ -16,6 +17,16 @@ from typing import Any, Literal
16
17
 
17
18
  import numpy as np
18
19
  import pandas as pd
20
+
21
+ try:
22
+ import swanlab # type: ignore
23
+ except ModuleNotFoundError:
24
+ swanlab = None
25
+ try:
26
+ import wandb # type: ignore
27
+ except ModuleNotFoundError:
28
+ wandb = None
29
+
19
30
  import torch
20
31
  import torch.distributed as dist
21
32
  import torch.nn as nn
@@ -60,8 +71,8 @@ from nextrec.loss import (
60
71
  InfoNCELoss,
61
72
  SampledSoftmaxLoss,
62
73
  TripletLoss,
63
- get_loss_fn,
64
74
  )
75
+ from nextrec.utils.loss import get_loss_fn
65
76
  from nextrec.loss.grad_norm import get_grad_norm_shared_params
66
77
  from nextrec.utils.console import display_metrics_table, progress
67
78
  from nextrec.utils.torch_utils import (
@@ -74,8 +85,20 @@ from nextrec.utils.torch_utils import (
74
85
  to_tensor,
75
86
  )
76
87
  from nextrec.utils.config import safe_value
77
- from nextrec.utils.model import compute_ranking_loss
78
- from nextrec.utils.types import LossName, OptimizerName, SchedulerName
88
+ from nextrec.utils.model import (
89
+ compute_ranking_loss,
90
+ get_loss_list,
91
+ resolve_loss_weights,
92
+ get_training_modes,
93
+ )
94
+ from nextrec.utils.types import (
95
+ LossName,
96
+ OptimizerName,
97
+ SchedulerName,
98
+ TrainingModeName,
99
+ TaskTypeName,
100
+ MetricsName,
101
+ )
79
102
 
80
103
 
81
104
  class BaseModel(SummarySet, FeatureSet, nn.Module):
@@ -84,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
84
107
  raise NotImplementedError
85
108
 
86
109
  @property
87
- def default_task(self) -> str | list[str]:
110
+ def default_task(self) -> TaskTypeName | list[TaskTypeName]:
88
111
  raise NotImplementedError
89
112
 
90
113
  def __init__(
@@ -94,11 +117,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
94
117
  sequence_features: list[SequenceFeature] | None = None,
95
118
  target: list[str] | str | None = None,
96
119
  id_columns: list[str] | str | None = None,
97
- task: str | list[str] | None = None,
98
- training_mode: (
99
- Literal["pointwise", "pairwise", "listwise"]
100
- | list[Literal["pointwise", "pairwise", "listwise"]]
101
- ) = "pointwise",
120
+ task: TaskTypeName | list[TaskTypeName] | None = None,
121
+ training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
102
122
  embedding_l1_reg: float = 0.0,
103
123
  dense_l1_reg: float = 0.0,
104
124
  embedding_l2_reg: float = 0.0,
@@ -136,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
136
156
  world_size: Number of processes (defaults to env WORLD_SIZE).
137
157
  local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
138
158
  ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
159
+
160
+ Note:
161
+ Optimizer, scheduler, and loss are configured via compile().
139
162
  """
140
163
  super(BaseModel, self).__init__()
141
164
 
@@ -168,25 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
168
191
  dense_features, sparse_features, sequence_features, target, id_columns
169
192
  )
170
193
 
171
- self.task = self.default_task if task is None else task
194
+ self.task = task or self.default_task
172
195
  self.nums_task = len(self.task) if isinstance(self.task, list) else 1
173
- if isinstance(training_mode, list):
174
- training_modes = list(training_mode)
175
- if len(training_modes) != self.nums_task:
176
- raise ValueError(
177
- "[BaseModel-init Error] training_mode list length must match number of tasks."
178
- )
179
- else:
180
- training_modes = [training_mode] * self.nums_task
181
- if any(
182
- mode not in {"pointwise", "pairwise", "listwise"}
183
- for mode in training_modes
184
- ):
185
- raise ValueError(
186
- "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
187
- )
188
- self.training_modes = training_modes
189
- self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
196
+ self.training_modes = get_training_modes(training_mode, self.nums_task)
197
+ self.training_mode = (
198
+ self.training_modes if self.nums_task > 1 else self.training_modes[0]
199
+ )
190
200
 
191
201
  self.embedding_l1_reg = embedding_l1_reg
192
202
  self.dense_l1_reg = dense_l1_reg
@@ -194,7 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
194
204
  self.dense_l2_reg = dense_l2_reg
195
205
  self.regularization_weights = []
196
206
  self.embedding_params = []
197
- self.loss_weight = None
207
+
208
+ self.ignore_label = None
209
+ self.compiled = False
198
210
 
199
211
  self.max_gradient_norm = 1.0
200
212
  self.logger_initialized = False
@@ -407,6 +419,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
407
419
  loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
408
420
  loss_params: dict | list[dict] | None = None,
409
421
  loss_weights: int | float | list[int | float] | dict | str | None = None,
422
+ ignore_label: int | float | None = -1,
410
423
  ):
411
424
  """
412
425
  Configure the model for training.
@@ -419,34 +432,17 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
419
432
  loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
420
433
  loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
421
434
  Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
435
+ ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
422
436
  """
437
+ self.ignore_label = ignore_label
423
438
  default_losses = {
424
439
  "pointwise": "bce",
425
440
  "pairwise": "bpr",
426
441
  "listwise": "listnet",
427
442
  }
428
- effective_loss = loss
429
- if effective_loss is None:
430
- loss_list = [default_losses[mode] for mode in self.training_modes]
431
- elif isinstance(effective_loss, list):
432
- if not effective_loss:
433
- loss_list = [default_losses[mode] for mode in self.training_modes]
434
- else:
435
- if len(effective_loss) != self.nums_task:
436
- raise ValueError(
437
- f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
438
- )
439
- loss_list = list(effective_loss)
440
- else:
441
- loss_list = [effective_loss] * self.nums_task
442
-
443
- for idx, mode in enumerate(self.training_modes):
444
- if isinstance(loss_list[idx], str) and loss_list[idx] in {
445
- "bce",
446
- "binary_crossentropy",
447
- }:
448
- if mode in {"pairwise", "listwise"}:
449
- loss_list[idx] = default_losses[mode]
443
+ loss_list = get_loss_list(
444
+ loss, self.training_modes, self.nums_task, default_losses
445
+ )
450
446
  self.loss_params = loss_params or {}
451
447
  optimizer_params = optimizer_params or {}
452
448
  self.optimizer_name = (
@@ -510,36 +506,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
510
506
  nums_task=self.nums_task, device=self.device, **grad_norm_params
511
507
  )
512
508
  self.loss_weights = None
513
- elif loss_weights is None:
514
- self.loss_weights = None
515
- elif self.nums_task == 1:
516
- if isinstance(loss_weights, (list, tuple)):
517
- if len(loss_weights) != 1:
518
- raise ValueError(
519
- "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
520
- )
521
- loss_weights = loss_weights[0]
522
- self.loss_weights = [float(loss_weights)] # type: ignore
523
509
  else:
524
- if isinstance(loss_weights, (int, float)):
525
- weights = [float(loss_weights)] * self.nums_task
526
- elif isinstance(loss_weights, (list, tuple)):
527
- weights = [float(w) for w in loss_weights]
528
- if len(weights) != self.nums_task:
529
- raise ValueError(
530
- f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
531
- )
532
- else:
533
- raise TypeError(
534
- f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
535
- )
536
- self.loss_weights = weights
510
+ self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
511
+ self.compiled = True
537
512
 
538
513
  def compute_loss(self, y_pred, y_true):
539
514
  if y_true is None:
540
515
  raise ValueError(
541
516
  "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
542
517
  )
518
+
543
519
  # single-task
544
520
  if self.nums_task == 1:
545
521
  if y_pred.dim() == 1:
@@ -547,13 +523,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
547
523
  if y_true.dim() == 1:
548
524
  y_true = y_true.view(-1, 1)
549
525
  if y_pred.shape != y_true.shape:
550
- raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
551
- loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
552
- if loss_fn is None:
553
526
  raise ValueError(
554
- "[BaseModel-compute_loss Error] Loss function is not configured. Call compile() first."
527
+ f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
555
528
  )
529
+
530
+ loss_fn = self.loss_fn[0]
531
+
532
+ if self.ignore_label is not None:
533
+ valid_mask = y_true != self.ignore_label
534
+ if valid_mask.dim() > 1:
535
+ valid_mask = valid_mask.all(dim=1)
536
+ if not torch.any(valid_mask): # if no valid labels, return zero loss
537
+ return y_pred.sum() * 0.0
538
+
539
+ y_pred = y_pred[valid_mask]
540
+ y_true = y_true[valid_mask]
541
+
556
542
  mode = self.training_modes[0]
543
+
557
544
  task_dim = (
558
545
  self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
559
546
  )
@@ -584,7 +571,19 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
584
571
  for i, (start, end) in enumerate(slices): # type: ignore
585
572
  y_pred_i = y_pred[:, start:end]
586
573
  y_true_i = y_true[:, start:end]
574
+ # mask ignored labels
575
+ if self.ignore_label is not None:
576
+ valid_mask = y_true_i != self.ignore_label
577
+ if valid_mask.dim() > 1:
578
+ valid_mask = valid_mask.all(dim=1)
579
+ if not torch.any(valid_mask):
580
+ task_losses.append(y_pred_i.sum() * 0.0)
581
+ continue
582
+ y_pred_i = y_pred_i[valid_mask]
583
+ y_true_i = y_true_i[valid_mask]
584
+
587
585
  mode = self.training_modes[i]
586
+
588
587
  if mode in {"pairwise", "listwise"}:
589
588
  task_loss = compute_ranking_loss(
590
589
  training_mode=mode,
@@ -594,7 +593,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
594
593
  )
595
594
  else:
596
595
  task_loss = self.loss_fn[i](y_pred_i, y_true_i)
596
+ # task_loss = normalize_task_loss(
597
+ # task_loss, valid_count, total_count
598
+ # ) # normalize by valid samples to avoid loss scale issues
597
599
  task_losses.append(task_loss)
600
+
598
601
  if self.grad_norm is not None:
599
602
  if self.grad_norm_shared_params is None:
600
603
  self.grad_norm_shared_params = get_grad_norm_shared_params(
@@ -651,7 +654,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
651
654
  train_data=None,
652
655
  valid_data=None,
653
656
  metrics: (
654
- list[str] | dict[str, list[str]] | None
657
+ list[MetricsName] | dict[str, list[MetricsName]] | None
655
658
  ) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
656
659
  epochs: int = 1,
657
660
  shuffle: bool = True,
@@ -665,6 +668,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
665
668
  use_tensorboard: bool = True,
666
669
  use_wandb: bool = False,
667
670
  use_swanlab: bool = False,
671
+ wandb_api: str | None = None,
672
+ swanlab_api: str | None = None,
668
673
  wandb_kwargs: dict | None = None,
669
674
  swanlab_kwargs: dict | None = None,
670
675
  auto_ddp_sampler: bool = True,
@@ -694,6 +699,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
694
699
  use_tensorboard: Enable tensorboard logging.
695
700
  use_wandb: Enable Weights & Biases logging.
696
701
  use_swanlab: Enable SwanLab logging.
702
+ wandb_api: W&B API key for non-tty login.
703
+ swanlab_api: SwanLab API key for non-tty login.
697
704
  wandb_kwargs: Optional kwargs for wandb.init(...).
698
705
  swanlab_kwargs: Optional kwargs for swanlab.init(...).
699
706
  auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False when data is already sharded per rank.
@@ -711,6 +718,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
711
718
  )
712
719
  self.to(self.device)
713
720
 
721
+ if not self.compiled:
722
+ self.compile(
723
+ optimizer="adam",
724
+ optimizer_params={},
725
+ scheduler=None,
726
+ scheduler_params={},
727
+ loss=None,
728
+ loss_params={},
729
+ )
730
+
714
731
  if (
715
732
  self.distributed
716
733
  and dist.is_available()
@@ -785,6 +802,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
785
802
  }
786
803
  training_config: dict = safe_value(training_config) # type: ignore
787
804
 
805
+ if self.is_main_process:
806
+ is_tty = sys.stdin.isatty() and sys.stdout.isatty()
807
+ if not is_tty:
808
+ if use_wandb and wandb_api:
809
+ if wandb is None:
810
+ logging.warning(
811
+ "[BaseModel-fit] wandb not installed, skip wandb login."
812
+ )
813
+ else:
814
+ wandb.login(key=wandb_api)
815
+ if use_swanlab and swanlab_api:
816
+ if swanlab is None:
817
+ logging.warning(
818
+ "[BaseModel-fit] swanlab not installed, skip swanlab login."
819
+ )
820
+ else:
821
+ swanlab.login(api_key=swanlab_api)
822
+
788
823
  self.training_logger = (
789
824
  TrainingLogger(
790
825
  session=self.session,
@@ -2164,7 +2199,7 @@ class BaseMatchModel(BaseModel):
2164
2199
  scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
2165
2200
  loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
2166
2201
  loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
2167
- loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
2202
+ loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
2168
2203
  """
2169
2204
  if self.training_mode not in self.support_training_modes:
2170
2205
  raise ValueError(
@@ -1,5 +1,9 @@
1
1
  """
2
2
  Summary utilities for BaseModel.
3
+
4
+ Date: create on 03/12/2025
5
+ Checkpoint: edit on 29/12/2025
6
+ Author: Yang Zhou, zyaztec@gmail.com
3
7
  """
4
8
 
5
9
  from __future__ import annotations
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader
12
16
 
13
17
  from nextrec.basic.loggers import colorize, format_kv
14
18
  from nextrec.data.data_processing import extract_label_arrays, get_data_length
19
+ from nextrec.utils.types import TaskTypeName
15
20
 
16
21
 
17
22
  class SummarySet:
23
+ model_name: str
24
+ dense_features: list[Any]
25
+ sparse_features: list[Any]
26
+ sequence_features: list[Any]
27
+ task: TaskTypeName | list[TaskTypeName]
28
+ target_columns: list[str]
29
+ nums_task: int
30
+ metrics: Any
31
+ device: Any
32
+ optimizer_name: str
33
+ optimizer_params: dict[str, Any]
34
+ scheduler_name: str | None
35
+ scheduler_params: dict[str, Any]
36
+ loss_config: Any
37
+ loss_weights: Any
38
+ grad_norm: Any
39
+ embedding_l1_reg: float
40
+ embedding_l2_reg: float
41
+ dense_l1_reg: float
42
+ dense_l2_reg: float
43
+ early_stop_patience: int
44
+ max_gradient_norm: float | None
45
+ metrics_sample_limit: int | None
46
+ session_id: str | None
47
+ features_config_path: str
48
+ checkpoint_path: str
49
+ train_data_summary: dict[str, Any] | None
50
+ valid_data_summary: dict[str, Any] | None
51
+
18
52
  def build_data_summary(
19
53
  self, data: Any, data_loader: DataLoader | None, sample_key: str
20
54
  ):
@@ -305,7 +339,7 @@ class SummarySet:
305
339
  lines = details.get("lines", [])
306
340
  logger.info(f"{target_name}:")
307
341
  for label, value in lines:
308
- logger.info(format_kv(label, value))
342
+ logger.info(f" {format_kv(label, value)}")
309
343
 
310
344
  if self.valid_data_summary:
311
345
  if self.train_data_summary:
@@ -320,4 +354,4 @@ class SummarySet:
320
354
  lines = details.get("lines", [])
321
355
  logger.info(f"{target_name}:")
322
356
  for label, value in lines:
323
- logger.info(format_kv(label, value))
357
+ logger.info(f" {format_kv(label, value)}")