nextrec 0.4.33__tar.gz → 0.4.34__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194)
  1. {nextrec-0.4.33 → nextrec-0.4.34}/PKG-INFO +4 -4
  2. {nextrec-0.4.33 → nextrec-0.4.34}/README.md +3 -3
  3. {nextrec-0.4.33 → nextrec-0.4.34}/README_en.md +3 -3
  4. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/conf.py +1 -1
  5. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/nextrec.utils.rst +0 -8
  6. nextrec-0.4.34/nextrec/__version__.py +1 -0
  7. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/activation.py +14 -16
  8. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/asserts.py +1 -22
  9. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/callback.py +2 -2
  10. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/features.py +6 -37
  11. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/heads.py +13 -1
  12. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/layers.py +9 -33
  13. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/loggers.py +3 -2
  14. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/metrics.py +85 -4
  15. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/model.py +7 -4
  16. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/summary.py +88 -42
  17. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/cli.py +16 -12
  18. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/data/preprocessor.py +3 -1
  19. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/loss/grad_norm.py +78 -76
  20. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/ple.py +1 -0
  21. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/share_bottom.py +1 -0
  22. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/tree_base/base.py +1 -1
  23. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/__init__.py +2 -1
  24. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/config.py +1 -1
  25. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/console.py +1 -1
  26. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/torch_utils.py +63 -56
  27. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/types.py +43 -0
  28. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/NextRec-CLI.md +0 -2
  29. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -2
  30. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/predict_config.yaml +4 -3
  31. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/predict_config_template.yaml +3 -2
  32. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/train_config_template.yaml +0 -2
  33. {nextrec-0.4.33 → nextrec-0.4.34}/pyproject.toml +1 -1
  34. {nextrec-0.4.33 → nextrec-0.4.34}/test/run_tests.py +1 -0
  35. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/example_match.py +0 -1
  36. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/example_multitask.py +16 -3
  37. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/movielen_ranking_deepfm.py +0 -1
  38. nextrec-0.4.33/nextrec/__version__.py +0 -1
  39. nextrec-0.4.33/nextrec/models/representation/autorec.py +0 -0
  40. nextrec-0.4.33/nextrec/models/representation/bpr.py +0 -0
  41. nextrec-0.4.33/nextrec/models/representation/cl4srec.py +0 -0
  42. nextrec-0.4.33/nextrec/models/representation/lightgcn.py +0 -0
  43. nextrec-0.4.33/nextrec/models/representation/mf.py +0 -0
  44. nextrec-0.4.33/nextrec/models/representation/s3rec.py +0 -0
  45. nextrec-0.4.33/nextrec/models/sequential/sasrec.py +0 -0
  46. nextrec-0.4.33/nextrec/utils/feature.py +0 -29
  47. {nextrec-0.4.33 → nextrec-0.4.34}/.github/workflows/publish.yml +0 -0
  48. {nextrec-0.4.33 → nextrec-0.4.34}/.github/workflows/tests.yml +0 -0
  49. {nextrec-0.4.33 → nextrec-0.4.34}/.gitignore +0 -0
  50. {nextrec-0.4.33 → nextrec-0.4.34}/.readthedocs.yaml +0 -0
  51. {nextrec-0.4.33 → nextrec-0.4.34}/CODE_OF_CONDUCT.md +0 -0
  52. {nextrec-0.4.33 → nextrec-0.4.34}/CONTRIBUTING.md +0 -0
  53. {nextrec-0.4.33 → nextrec-0.4.34}/LICENSE +0 -0
  54. {nextrec-0.4.33 → nextrec-0.4.34}/MANIFEST.in +0 -0
  55. {nextrec-0.4.33 → nextrec-0.4.34}/assets/Feature Configuration.png +0 -0
  56. {nextrec-0.4.33 → nextrec-0.4.34}/assets/Model Parameters.png +0 -0
  57. {nextrec-0.4.33 → nextrec-0.4.34}/assets/Training Configuration.png +0 -0
  58. {nextrec-0.4.33 → nextrec-0.4.34}/assets/Training logs.png +0 -0
  59. {nextrec-0.4.33 → nextrec-0.4.34}/assets/logo.png +0 -0
  60. {nextrec-0.4.33 → nextrec-0.4.34}/assets/mmoe_tutorial.png +0 -0
  61. {nextrec-0.4.33 → nextrec-0.4.34}/assets/nextrec_diagram.png +0 -0
  62. {nextrec-0.4.33 → nextrec-0.4.34}/assets/test data.png +0 -0
  63. {nextrec-0.4.33 → nextrec-0.4.34}/dataset/ctcvr_task.csv +0 -0
  64. {nextrec-0.4.33 → nextrec-0.4.34}/dataset/ecommerce_task.csv +0 -0
  65. {nextrec-0.4.33 → nextrec-0.4.34}/dataset/match_task.csv +0 -0
  66. {nextrec-0.4.33 → nextrec-0.4.34}/dataset/movielens_100k.csv +0 -0
  67. {nextrec-0.4.33 → nextrec-0.4.34}/dataset/multitask_task.csv +0 -0
  68. {nextrec-0.4.33 → nextrec-0.4.34}/dataset/ranking_task.csv +0 -0
  69. {nextrec-0.4.33 → nextrec-0.4.34}/docs/en/Getting started guide.md +0 -0
  70. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/Makefile +0 -0
  71. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/index.md +0 -0
  72. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/make.bat +0 -0
  73. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/modules.rst +0 -0
  74. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/nextrec.basic.rst +0 -0
  75. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/nextrec.data.rst +0 -0
  76. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/nextrec.loss.rst +0 -0
  77. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/nextrec.rst +0 -0
  78. {nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/requirements.txt +0 -0
  79. {nextrec-0.4.33 → nextrec-0.4.34}/docs/zh/快速上手.md +0 -0
  80. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/__init__.py +0 -0
  81. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/__init__.py +0 -0
  82. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/session.py +0 -0
  83. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/data/__init__.py +0 -0
  84. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/data/batch_utils.py +0 -0
  85. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/data/data_processing.py +0 -0
  86. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/data/data_utils.py +0 -0
  87. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/data/dataloader.py +0 -0
  88. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/loss/__init__.py +0 -0
  89. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/loss/listwise.py +0 -0
  90. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/loss/pairwise.py +0 -0
  91. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/loss/pointwise.py +0 -0
  92. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/generative/__init__.py +0 -0
  93. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/generative/tiger.py +0 -0
  94. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/[pre]aitm.py +0 -0
  95. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/[pre]snr_trans.py +0 -0
  96. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/[pre]star.py +0 -0
  97. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/__init__.py +0 -0
  98. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/apg.py +0 -0
  99. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/cross_stitch.py +0 -0
  100. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/escm.py +0 -0
  101. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/esmm.py +0 -0
  102. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/hmoe.py +0 -0
  103. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/mmoe.py +0 -0
  104. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/pepnet.py +0 -0
  105. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/multi_task/poso.py +0 -0
  106. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/__init__.py +0 -0
  107. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/afm.py +0 -0
  108. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/autoint.py +0 -0
  109. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/dcn.py +0 -0
  110. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/dcn_v2.py +0 -0
  111. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/deepfm.py +0 -0
  112. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/dien.py +0 -0
  113. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/din.py +0 -0
  114. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/eulernet.py +0 -0
  115. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/ffm.py +0 -0
  116. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/fibinet.py +0 -0
  117. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/fm.py +0 -0
  118. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/lr.py +0 -0
  119. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/masknet.py +0 -0
  120. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/pnn.py +0 -0
  121. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/widedeep.py +0 -0
  122. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/ranking/xdeepfm.py +0 -0
  123. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/representation/__init__.py +0 -0
  124. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/representation/rqvae.py +0 -0
  125. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/retrieval/__init__.py +0 -0
  126. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/retrieval/dssm.py +0 -0
  127. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/retrieval/dssm_v2.py +0 -0
  128. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/retrieval/mind.py +0 -0
  129. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/retrieval/sdm.py +0 -0
  130. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/retrieval/youtube_dnn.py +0 -0
  131. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/sequential/hstu.py +0 -0
  132. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/tree_base/__init__.py +0 -0
  133. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/tree_base/catboost.py +0 -0
  134. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/tree_base/lightgbm.py +0 -0
  135. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/models/tree_base/xgboost.py +0 -0
  136. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/data.py +0 -0
  137. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/embedding.py +0 -0
  138. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/loss.py +0 -0
  139. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec/utils/model.py +0 -0
  140. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/feature_config.yaml +0 -0
  141. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
  142. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/apg.yaml +0 -0
  143. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
  144. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/cross_stitch.yaml +0 -0
  145. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
  146. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
  147. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/din.yaml +0 -0
  148. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/escm.yaml +0 -0
  149. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
  150. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
  151. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
  152. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/hmoe.yaml +0 -0
  153. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
  154. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
  155. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/pepnet.yaml +0 -0
  156. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
  157. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
  158. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
  159. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
  160. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
  161. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
  162. {nextrec-0.4.33 → nextrec-0.4.34}/nextrec_cli_preset/train_config.yaml +0 -0
  163. {nextrec-0.4.33 → nextrec-0.4.34}/pytest.ini +0 -0
  164. {nextrec-0.4.33 → nextrec-0.4.34}/requirements.txt +0 -0
  165. {nextrec-0.4.33 → nextrec-0.4.34}/scripts/format_code.py +0 -0
  166. {nextrec-0.4.33 → nextrec-0.4.34}/test/__init__.py +0 -0
  167. {nextrec-0.4.33 → nextrec-0.4.34}/test/conftest.py +0 -0
  168. {nextrec-0.4.33 → nextrec-0.4.34}/test/helpers.py +0 -0
  169. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_base_model_regularization.py +0 -0
  170. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_generative_models.py +0 -0
  171. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_layers.py +0 -0
  172. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_losses.py +0 -0
  173. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_match_models.py +0 -0
  174. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_multitask_models.py +0 -0
  175. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_preprocessor.py +0 -0
  176. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_ranking_models.py +0 -0
  177. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_utils_console.py +0 -0
  178. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_utils_data.py +0 -0
  179. {nextrec-0.4.33 → nextrec-0.4.34}/test/test_utils_embedding.py +0 -0
  180. {nextrec-0.4.33 → nextrec-0.4.34}/test_requirements.txt +0 -0
  181. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/distributed/example_distributed_training.py +0 -0
  182. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
  183. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/example_ranking_din.py +0 -0
  184. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/example_tree.py +0 -0
  185. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/movielen_match_dssm.py +0 -0
  186. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
  187. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  188. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  189. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
  190. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
  191. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0
  192. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/run_all_match_models.py +0 -0
  193. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/run_all_multitask_models.py +0 -0
  194. {nextrec-0.4.33 → nextrec-0.4.34}/tutorials/run_all_ranking_models.py +0 -0

{nextrec-0.4.33 → nextrec-0.4.34}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.4.33
+ Version: 0.4.34
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.33-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.34-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  中文文档 | [English Version](README_en.md)
@@ -254,11 +254,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。

- > 截止当前版本0.4.33,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.4.34,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.33,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.4.34,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|

{nextrec-0.4.33 → nextrec-0.4.34}/README.md

@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.33-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.34-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  中文文档 | [English Version](README_en.md)
@@ -193,11 +193,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。

- > 截止当前版本0.4.33,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+ > 截止当前版本0.4.34,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。

  ## 兼容平台

- 当前最新版本为0.4.33,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+ 当前最新版本为0.4.34,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:

  | 平台 | 配置 |
  |------|------|

{nextrec-0.4.33 → nextrec-0.4.34}/README_en.md

@@ -8,7 +8,7 @@
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.33-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.34-orange.svg)
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/zerolovesea/NextRec)

  English | [中文文档](README.md)
@@ -196,11 +196,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml

  Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.

- > As of version 0.4.33, NextRec CLI supports single-machine training; distributed training features are currently under development.
+ > As of version 0.4.34, NextRec CLI supports single-machine training; distributed training features are currently under development.

  ## Platform Compatibility

- The current version is 0.4.33. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
+ The current version is 0.4.34. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:

  | Platform | Configuration |
  |----------|---------------|

{nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/conf.py

@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
  project = "NextRec"
  copyright = "2026, Yang Zhou"
  author = "Yang Zhou"
- release = "0.4.33"
+ release = "0.4.34"

  extensions = [
  "myst_parser",

{nextrec-0.4.33 → nextrec-0.4.34}/docs/rtd/nextrec.utils.rst

@@ -28,14 +28,6 @@ nextrec.utils.data module
  :undoc-members:
  :show-inheritance:

- nextrec.utils.feature module
- ----------------------------
-
- .. automodule:: nextrec.utils.feature
- :members:
- :undoc-members:
- :show-inheritance:
-

  nextrec.utils.model module
  --------------------------

nextrec-0.4.34/nextrec/__version__.py

@@ -0,0 +1 @@
+ __version__ = "0.4.34"

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/activation.py

@@ -1,8 +1,8 @@
  """
- Activation function definitions for NextRec models.
+ Activation function definitions.

  Date: create on 27/10/2025
- Checkpoint: edit on 28/12/2025
+ Checkpoint: edit on 20/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -22,26 +22,24 @@ class Dice(nn.Module):
  where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
  """

- def __init__(self, emb_size: int, epsilon: float = 1e-9):
+ def __init__(self, emb_size: int, epsilon: float = 1e-3):
  super(Dice, self).__init__()
- self.epsilon = epsilon
  self.alpha = nn.Parameter(torch.zeros(emb_size))
- self.bn = nn.BatchNorm1d(emb_size)
+ self.bn = nn.BatchNorm1d(emb_size, eps=epsilon)

  def forward(self, x):
  # x shape: (batch_size, emb_size) or (batch_size, seq_len, emb_size)
- original_shape = x.shape
+ if x.dim() == 2: # (B, E)
+ x_norm = self.bn(x)
+ p = torch.sigmoid(x_norm)
+ return x * (self.alpha + (1 - self.alpha) * p)

- if x.dim() == 3:
- # For 3D input (batch_size, seq_len, emb_size), reshape to 2D
- batch_size, seq_len, emb_size = x.shape
- x = x.view(-1, emb_size)
- x_norm = self.bn(x)
- p = torch.sigmoid(x_norm)
- output = p * x + (1 - p) * self.alpha * x
- if len(original_shape) == 3:
- output = output.view(original_shape)
- return output
+ if x.dim() == 3: # (B, T, E)
+ b, t, e = x.shape
+ x2 = x.reshape(-1, e) # (B*T, E)
+ x_norm = self.bn(x2)
+ p = torch.sigmoid(x_norm).reshape(b, t, e)
+ return x * (self.alpha + (1 - self.alpha) * p)


  def activation_layer(
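
The Dice hunk above now routes 2D and 3D inputs through the shared BatchNorm1d explicitly. A small usage sketch of the refactored class, assuming PyTorch and the `Dice` implementation exactly as it appears in this diff (the example itself is illustrative and not part of the package):

```python
import torch
from nextrec.basic.activation import Dice

dice = Dice(emb_size=8)   # BatchNorm1d(8, eps=1e-3) plus a learnable per-dim alpha
dice.eval()               # use running statistics so the toy example is deterministic

x2d = torch.randn(4, 8)      # (batch, emb)
x3d = torch.randn(4, 5, 8)   # (batch, seq_len, emb)

# 2D input is normalized directly; 3D input is flattened to (batch * seq_len, emb)
# for BatchNorm and reshaped back, as in the new forward().
print(dice(x2d).shape)   # torch.Size([4, 8])
print(dice(x3d).shape)   # torch.Size([4, 5, 8])
```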

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/asserts.py

@@ -8,7 +8,7 @@ Author: Yang Zhou, zyaztec@gmail.com

  from __future__ import annotations

- from nextrec.utils.types import TaskTypeName, TrainingModeName
+ from nextrec.utils.types import TaskTypeName


  def assert_task(
@@ -49,24 +49,3 @@ def assert_task(
  raise ValueError(
  f"{model_name} requires task length {nums_task}, got {len(task)}."
  )
-
-
- def assert_training_mode(
- training_mode: TrainingModeName | list[TrainingModeName],
- nums_task: int,
- *,
- model_name: str,
- ) -> None:
- valid_modes = {"pointwise", "pairwise", "listwise"}
- if not isinstance(training_mode, list):
- raise TypeError(
- f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
- )
- if len(training_mode) != nums_task:
- raise ValueError(
- f"[{model_name}-init Error] training_mode list length must match number of tasks."
- )
- if any(mode not in valid_modes for mode in training_mode):
- raise ValueError(
- f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
- )

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/callback.py

@@ -2,7 +2,7 @@
  Callback System for Training Process

  Date: create on 27/10/2025
- Checkpoint: edit on 27/12/2025
+ Checkpoint: edit on 21/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -69,7 +69,7 @@ class Callback:

  class CallbackList:
  """
- Generates a list of callbacks
+ Generates a list of callbacks, used to manage and invoke multiple callbacks during training.
  """

  def __init__(self, callbacks: Optional[list[Callback]] = None):

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/features.py

@@ -8,10 +8,9 @@ Author: Yang Zhou, zyaztec@gmail.com

  import torch

- from typing import Literal
-
  from nextrec.utils.embedding import get_auto_embedding_dim
- from nextrec.utils.feature import to_list
+ from nextrec.utils.torch_utils import to_list
+ from nextrec.utils.types import EmbeddingInitType, SequenceCombinerType


  class BaseFeature:
@@ -29,15 +28,7 @@ class EmbeddingFeature(BaseFeature):
  embedding_name: str = "",
  embedding_dim: int | None = None,
  padding_idx: int = 0,
- init_type: Literal[
- "normal",
- "uniform",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- "orthogonal",
- ] = "normal",
+ init_type: EmbeddingInitType = "normal",
  init_params: dict | None = None,
  l1_reg: float = 0.0,
  l2_reg: float = 0.0,
@@ -73,23 +64,9 @@ class SequenceFeature(EmbeddingFeature):
  max_len: int = 50,
  embedding_name: str = "",
  embedding_dim: int | None = None,
- combiner: Literal[
- "mean",
- "sum",
- "concat",
- "dot_attention",
- "self_attention",
- ] = "mean",
+ combiner: SequenceCombinerType = "mean",
  padding_idx: int = 0,
- init_type: Literal[
- "normal",
- "uniform",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- "orthogonal",
- ] = "normal",
+ init_type: EmbeddingInitType = "normal",
  init_params: dict | None = None,
  l1_reg: float = 0.0,
  l2_reg: float = 0.0,
@@ -143,15 +120,7 @@ class SparseFeature(EmbeddingFeature):
  embedding_name: str = "",
  embedding_dim: int | None = None,
  padding_idx: int = 0,
- init_type: Literal[
- "normal",
- "uniform",
- "xavier_uniform",
- "xavier_normal",
- "kaiming_uniform",
- "kaiming_normal",
- "orthogonal",
- ] = "normal",
+ init_type: EmbeddingInitType = "normal",
  init_params: dict | None = None,
  l1_reg: float = 0.0,
  l2_reg: float = 0.0,
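
These hunks replace the repeated inline `Literal[...]` annotations with named aliases imported from `nextrec.utils.types` (which grows by 43 lines in this release, see file 27 in the list above). The alias definitions themselves are not shown in this diff; a plausible sketch, built only from the literal members that were removed here, might look like this:

```python
# Hypothetical sketch: the member values are taken from the annotations removed
# in the hunks above, but the actual definitions in nextrec/utils/types.py are
# not visible in this diff.
from typing import Literal

EmbeddingInitType = Literal[
    "normal",
    "uniform",
    "xavier_uniform",
    "xavier_normal",
    "kaiming_uniform",
    "kaiming_normal",
    "orthogonal",
]

SequenceCombinerType = Literal[
    "mean",
    "sum",
    "concat",
    "dot_attention",
    "self_attention",
]
```

Centralizing the unions this way keeps the constructor signatures of EmbeddingFeature, SequenceFeature, and SparseFeature short while leaving the accepted string values unchanged.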

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/heads.py

@@ -2,7 +2,7 @@
  Task head implementations for NextRec models.

  Date: create on 23/12/2025
- Checkpoint: edit on 27/12/2025
+ Checkpoint: edit on 22/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -24,6 +24,12 @@ class TaskHead(nn.Module):

  This wraps PredictionLayer so models can depend on a "Head" abstraction
  without changing their existing forward signatures.
+
+ Args:
+ task_type: The type of task(s) this head is responsible for.
+ task_dims: The dimensionality of each task's output.
+ use_bias: Whether to include a bias term in the prediction layer.
+ return_logits: Whether to return raw logits or apply activation.
  """

  def __init__(
@@ -56,6 +62,12 @@ class RetrievalHead(nn.Module):

  It computes similarity for pointwise training/inference, and returns
  raw embeddings for in-batch negative sampling in pairwise/listwise modes.
+
+ Args:
+ similarity_metric: The metric used to compute similarity between embeddings.
+ temperature: Scaling factor for similarity scores.
+ training_mode: The training mode, which can be pointwise, pairwise, or listwise.
+ apply_sigmoid: Whether to apply sigmoid activation to the similarity scores in pointwise mode.
  """

  def __init__(

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/layers.py

@@ -2,7 +2,7 @@
  Layer implementations used across NextRec.

  Date: create on 27/10/2025
- Checkpoint: edit on 27/12/2025
+ Checkpoint: edit on 22/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -20,15 +20,13 @@ import torch.nn.functional as F
  from nextrec.basic.activation import activation_layer
  from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
  from nextrec.utils.torch_utils import get_initializer
- from nextrec.utils.types import ActivationName
+ from nextrec.utils.types import ActivationName, TaskTypeName


  class PredictionLayer(nn.Module):
  def __init__(
  self,
- task_type: (
- Literal["binary", "regression"] | list[Literal["binary", "regression"]]
- ) = "binary",
+ task_type: TaskTypeName | list[TaskTypeName] = "binary",
  task_dims: int | list[int] | None = None,
  use_bias: bool = True,
  return_logits: bool = False,
@@ -92,10 +90,9 @@ class PredictionLayer(nn.Module):
  if self.return_logits:
  outputs.append(task_logits)
  continue
- task = task_type.lower()
- if task == "binary":
+ if task_type == "binary":
  outputs.append(torch.sigmoid(task_logits))
- elif task == "regression":
+ elif task_type == "regression":
  outputs.append(task_logits)
  else:
  raise ValueError(
@@ -897,30 +894,7 @@ class AttentionPoolingLayer(nn.Module):
  self,
  embedding_dim: int,
  hidden_units: list = [80, 40],
- activation: Literal[
- "dice",
- "relu",
- "relu6",
- "elu",
- "selu",
- "leaky_relu",
- "prelu",
- "gelu",
- "sigmoid",
- "tanh",
- "softplus",
- "softsign",
- "hardswish",
- "mish",
- "silu",
- "swish",
- "hardsigmoid",
- "tanhshrink",
- "softshrink",
- "none",
- "linear",
- "identity",
- ] = "sigmoid",
+ activation: ActivationName = "sigmoid",
  use_softmax: bool = False,
  ):
  super().__init__()
@@ -1029,7 +1003,9 @@ class RMSNorm(torch.nn.Module):


  class DomainBatchNorm(nn.Module):
- """Domain-specific BatchNorm (applied per-domain with a shared interface)."""
+ """
+ Domain-specific BatchNorm (applied per-domain with a shared interface).
+ """

  def __init__(self, num_features: int, num_domains: int):
  super().__init__()
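
The second and third hunks above narrow PredictionLayer's `task_type` to the shared `TaskTypeName` alias and drop the `.lower()` normalization, so each task's logits go through a sigmoid for "binary" and pass through unchanged for "regression". A standalone sketch of that branch (a hypothetical helper mirroring the logic shown, not the library's own API):

```python
import torch

def apply_task_activation(task_logits: torch.Tensor, task_type: str) -> torch.Tensor:
    # Mirrors the per-task branch in PredictionLayer.forward after this change:
    # "binary" -> sigmoid over the logits, "regression" -> raw logits,
    # anything else raises, since task_type is no longer lower-cased.
    if task_type == "binary":
        return torch.sigmoid(task_logits)
    if task_type == "regression":
        return task_logits
    raise ValueError(f"Unsupported task type: {task_type}")

print(apply_task_activation(torch.tensor([0.0, 2.0]), "binary"))      # tensor([0.5000, 0.8808])
print(apply_task_activation(torch.tensor([0.0, 2.0]), "regression"))  # tensor([0., 2.])
```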

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/loggers.py

@@ -2,7 +2,7 @@
  NextRec Basic Loggers

  Date: create on 27/10/2025
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 22/01/2026
  Author: Yang Zhou, zyaztec@gmail.com
  """

@@ -99,7 +99,8 @@ def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:


  def setup_logger(session_id: str | os.PathLike | None = None):
- """Set up a logger that logs to both console and a file with ANSI formatting.
+ """
+ Set up a logger that logs to both console and a file with ANSI formatting.
  Only console output has colors; file output is stripped of ANSI codes.

  Logs are stored under ``log/<experiment_id>/logs`` by default. A stable

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/metrics.py

@@ -23,7 +23,6 @@ from sklearn.metrics import (
  )
  from nextrec.utils.types import TaskTypeName, MetricsName

-
  TASK_DEFAULT_METRICS = {
  "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
  "regression": ["mse", "mae", "rmse", "r2", "mape"],
@@ -334,6 +333,60 @@ def compute_map_at_k(
  return float(np.mean(aps)) if aps else 0.0


+ def compute_topk_counts(
+ y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+ ) -> tuple[int, int, int]:
+ """Compute Top-K% sample size, hits, and positives for binary labels."""
+ y_true = (y_true > 0).astype(int)
+ n = y_true.size
+ if n == 0:
+ return 0, 0, 0
+ if k_percent <= 0:
+ return 0, 0, int(y_true.sum())
+ if k_percent >= 100:
+ k_count = n
+ else:
+ k_count = int(np.ceil(n * (k_percent / 100.0)))
+ k_count = max(k_count, 1)
+ order = np.argsort(y_pred)[::-1]
+ topk = order[:k_count]
+ hits = int(y_true[topk].sum())
+ total_pos = int(y_true.sum())
+ return k_count, hits, total_pos
+
+
+ def compute_topk_precision(
+ y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+ ) -> float:
+ """Compute Top-K% Precision."""
+ k_count, hits, _ = compute_topk_counts(y_true, y_pred, k_percent)
+ if k_count == 0:
+ return 0.0
+ return float(hits / k_count)
+
+
+ def compute_topk_recall(
+ y_true: np.ndarray, y_pred: np.ndarray, k_percent: int
+ ) -> float:
+ """Compute Top-K% Recall."""
+ _, hits, total_pos = compute_topk_counts(y_true, y_pred, k_percent)
+ if total_pos == 0:
+ return 0.0
+ return float(hits / total_pos)
+
+
+ def compute_lift_at_k(y_true: np.ndarray, y_pred: np.ndarray, k_percent: int) -> float:
+ """Compute Lift@K from Top-K% precision and overall rate."""
+ k_count, hits, total_pos = compute_topk_counts(y_true, y_pred, k_percent)
+ if k_count == 0:
+ return 0.0
+ base_rate = total_pos / float(y_true.size)
+ if base_rate == 0.0:
+ return 0.0
+ precision = hits / float(k_count)
+ return float(precision / base_rate)
+
+
  def compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Compute Cosine Separation."""
  y_true = (y_true > 0).astype(int)
@@ -399,11 +452,11 @@ def configure_metrics(
  if primary_task not in TASK_DEFAULT_METRICS:
  raise ValueError(f"Unsupported task type: {primary_task}")
  metrics_list = TASK_DEFAULT_METRICS[primary_task]
- best_metrics_mode = getbest_metric_mode(metrics_list[0], primary_task)
+ best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
  return metrics_list, task_specific_metrics, best_metrics_mode


- def getbest_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
+ def get_best_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -> str:
  """Determine if metric should be maximized or minimized."""
  # Metrics that should be maximized
  if first_metric in {
@@ -429,6 +482,9 @@ getbest_metric_mode(first_metric: MetricsName, primary_task: TaskTypeName) -
  or first_metric.startswith("mrr@")
  or first_metric.startswith("ndcg@")
  or first_metric.startswith("map@")
+ or first_metric.startswith("topk_recall@")
+ or first_metric.startswith("topk_precision@")
+ or first_metric.startswith("lift@")
  ):
  return "max"
  # Cosine separation should be maximized
@@ -457,6 +513,15 @@ def compute_single_metric(

  y_p_binary = (y_pred > 0.5).astype(int)
  try:
+ if metric.startswith("topk_recall@"):
+ k_percent = int(metric.split("@")[1])
+ return compute_topk_recall(y_true, y_pred, k_percent)
+ if metric.startswith("topk_precision@"):
+ k_percent = int(metric.split("@")[1])
+ return compute_topk_precision(y_true, y_pred, k_percent)
+ if metric.startswith("lift@"):
+ k_percent = int(metric.split("@")[1])
+ return compute_lift_at_k(y_true, y_pred, k_percent)
  if metric.startswith("recall@"):
  k = int(metric.split("@")[1])
  return compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
@@ -650,7 +715,23 @@ def evaluate_metrics(
  allowed_metrics = metric_allowlist.get(task_type)
  for metric in metrics:
  if allowed_metrics is not None and metric not in allowed_metrics:
- continue
+ if metric.startswith(
+ (
+ "recall@",
+ "precision@",
+ "hitrate@",
+ "hr@",
+ "mrr@",
+ "ndcg@",
+ "map@",
+ "topk_recall@",
+ "topk_precision@",
+ "lift@",
+ )
+ ):
+ pass
+ else:
+ continue
  y_true_task = y_true[:, task_idx]
  y_pred_task = y_pred[:, task_idx]
  task_user_ids = user_ids
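
The Top-K% helpers added above are plain NumPy functions, so their behavior is easy to check by hand. A small worked example, assuming NumPy and the functions exactly as added in this hunk (in metric strings they would be referenced as `topk_precision@20`, `topk_recall@20`, and `lift@20`, per the new branches in `compute_single_metric`):

```python
import numpy as np
from nextrec.basic.metrics import (
    compute_lift_at_k,
    compute_topk_precision,
    compute_topk_recall,
)

# 10 samples, 3 positives; the two highest scores are both positives.
y_true = np.array([1, 1, 0, 0, 1, 0, 0, 0, 0, 0])
y_pred = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05])

# Top 20% of 10 samples -> k_count = 2, hits = 2, total positives = 3.
print(compute_topk_precision(y_true, y_pred, 20))  # 2 / 2 = 1.0
print(compute_topk_recall(y_true, y_pred, 20))     # 2 / 3 ≈ 0.667
print(compute_lift_at_k(y_true, y_pred, 20))       # 1.0 / 0.3 ≈ 3.333
```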

{nextrec-0.4.33 → nextrec-0.4.34}/nextrec/basic/model.py

@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 01/01/2026
+ Checkpoint: edit on 22/01/2026
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -155,9 +155,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.

  distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
- rank: Global rank (defaults to env RANK).
- world_size: Number of processes (defaults to env WORLD_SIZE).
- local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
+ rank: Global rank (defaults to env RANK). e.g., 0 for the main process.
+ world_size: Number of processes (defaults to env WORLD_SIZE). e.g., 4 for a 4-process training.
+ local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK). e.g., 0 for the first GPU.
  ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.

  Note:
@@ -1351,6 +1351,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
  nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
  self.optimizer_fn.step()
  if self.grad_norm is not None:
+ # Synchronize GradNorm buffers across DDP ranks before stepping
+ if self.distributed and dist.is_available() and dist.is_initialized():
+ self.grad_norm.sync()
  self.grad_norm.step()
  accumulated_loss += loss.item()
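
The last hunk above makes the training loop call `grad_norm.sync()` on every DDP rank before `grad_norm.step()`. The implementation lives in `nextrec/loss/grad_norm.py`, which also changed in this release but is not shown here; as a rough, assumption-laden illustration, a sync of this kind typically averages per-task loss statistics across ranks, along these lines (hypothetical helper and attribute names, not the library's actual code):

```python
import torch
import torch.distributed as dist

def sync_task_stats(task_losses: torch.Tensor) -> torch.Tensor:
    """Average per-task loss statistics across all DDP ranks.

    Hypothetical sketch: GradNorm re-balances task weights from per-task
    loss/gradient statistics, so averaging those statistics across ranks
    keeps the learned weights consistent before each GradNorm update.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(task_losses, op=dist.ReduceOp.SUM)  # sum over ranks in place
        task_losses /= dist.get_world_size()                # then divide to get the mean
    return task_losses
```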