nextrec 0.4.11__tar.gz → 0.4.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.4.11 → nextrec-0.4.12}/PKG-INFO +6 -6
- {nextrec-0.4.11 → nextrec-0.4.12}/README.md +5 -5
- {nextrec-0.4.11 → nextrec-0.4.12}/README_en.md +4 -4
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/conf.py +1 -1
- nextrec-0.4.12/nextrec/__version__.py +1 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/callback.py +44 -54
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/features.py +35 -22
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/layers.py +64 -68
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/loggers.py +2 -2
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/metrics.py +9 -5
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/model.py +162 -106
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/cli.py +16 -5
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/data/preprocessor.py +4 -4
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/loss/loss_utils.py +1 -1
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/eulernet.py +44 -75
- nextrec-0.4.12/nextrec/models/ranking/ffm.py +275 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/lr.py +1 -3
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/__init__.py +2 -1
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/console.py +9 -1
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/model.py +14 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/pyproject.toml +1 -1
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_ranking_models.py +37 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/movielen_match_dssm.py +2 -1
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/run_all_match_models.py +4 -1
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/run_all_ranking_models.py +2 -0
- nextrec-0.4.11/nextrec/__version__.py +0 -1
- nextrec-0.4.11/nextrec/models/ranking/ffm.py +0 -0
- nextrec-0.4.11/tutorials/example_match_dssm.py +0 -130
- {nextrec-0.4.11 → nextrec-0.4.12}/.github/workflows/publish.yml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/.github/workflows/tests.yml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/.gitignore +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/.readthedocs.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/CONTRIBUTING.md +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/LICENSE +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/MANIFEST.in +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/Feature Configuration.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/Model Parameters.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/Training Configuration.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/Training logs.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/logo.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/mmoe_tutorial.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/nextrec_diagram.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/assets/test data.png +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/dataset/ctcvr_task.csv +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/dataset/ecommerce_task.csv +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/dataset/match_task.csv +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/dataset/multitask_task.csv +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/dataset/ranking_task.csv +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/en/Getting started guide.md +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/Makefile +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/index.md +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/make.bat +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/modules.rst +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/nextrec.basic.rst +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/nextrec.data.rst +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/nextrec.loss.rst +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/nextrec.rst +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/nextrec.utils.rst +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/rtd/requirements.txt +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md" +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/activation.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/basic/session.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/data/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/data/batch_utils.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/data/data_processing.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/data/data_utils.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/data/dataloader.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/loss/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/loss/listwise.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/loss/pairwise.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/loss/pointwise.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/generative/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/multi_task/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/multi_task/esmm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/multi_task/mmoe.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/multi_task/ple.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/multi_task/poso.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/multi_task/share_bottom.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/afm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/autoint.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/dcn.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/dcn_v2.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/deepfm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/dien.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/din.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/fibinet.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/fm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/masknet.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/pnn.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/widedeep.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/ranking/xdeepfm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/autorec.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/bpr.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/cl4srec.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/lightgcn.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/mf.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/rqvae.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/representation/s3rec.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/retrieval/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/retrieval/dssm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/retrieval/dssm_v2.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/retrieval/mind.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/retrieval/sdm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/retrieval/youtube_dnn.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/sequential/hstu.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/models/sequential/sasrec.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/config.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/data.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/feature.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec/utils/torch_utils.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/NextRec-CLI.md +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/feature_config.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/din.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/predict_config.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/predict_config_template.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/train_config.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/nextrec_cli_preset/train_config_template.yaml +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/pytest.ini +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/requirements.txt +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/scripts/format_code.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/__init__.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/conftest.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/helpers.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/run_tests.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_base_model_regularization.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_generative_models.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_layers.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_losses.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_match_models.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_multitask_models.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_preprocessor.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_utils_console.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_utils_data.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test/test_utils_embedding.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/test_requirements.txt +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/distributed/example_distributed_training.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/example_multitask.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/example_ranking_din.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/movielen_ranking_deepfm.py +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/notebooks/zh//344/275/277/347/224/250RQ-VAE/346/236/204/345/273/272/350/257/255/344/271/211ID.ipynb" +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/notebooks/zh//345/246/202/344/275/225/344/275/277/347/224/250DataProcessor/350/277/233/350/241/214/351/242/204/345/244/204/347/220/206.ipynb" +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/notebooks/zh//345/277/253/351/200/237/345/205/245/351/227/250nextrec.ipynb" +0 -0
- {nextrec-0.4.11 → nextrec-0.4.12}/tutorials/run_all_multitask_models.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.4.11
|
|
3
|
+
Version: 0.4.12
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -66,7 +66,7 @@ Description-Content-Type: text/markdown
|
|
|
66
66
|

|
|
67
67
|

|
|
68
68
|

|
|
69
|
-

|
|
70
70
|
|
|
71
71
|
中文文档 | [English Version](README_en.md)
|
|
72
72
|
|
|
@@ -99,7 +99,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
|
|
|
99
99
|
|
|
100
100
|
## NextRec近期进展
|
|
101
101
|
|
|
102
|
-
- **12/12/2025** 在v0.4.
|
|
102
|
+
- **12/12/2025** 在v0.4.12中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
|
|
103
103
|
- **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
|
|
104
104
|
- **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
|
|
105
105
|
- **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
|
|
@@ -240,11 +240,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
|
|
|
240
240
|
nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
241
241
|
```
|
|
242
242
|
|
|
243
|
-
> 截止当前版本0.4.11,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
243
|
+
> 截止当前版本0.4.12,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
244
244
|
|
|
245
245
|
## 兼容平台
|
|
246
246
|
|
|
247
|
-
当前最新版本为0.4.11,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
247
|
+
当前最新版本为0.4.12,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
248
248
|
|
|
249
249
|
| 平台 | 配置 |
|
|
250
250
|
|------|------|
|
|
@@ -262,7 +262,7 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
262
262
|
| [FM](nextrec/models/ranking/fm.py) | Factorization Machines | ICDM 2010 | 已支持 |
|
|
263
263
|
| [LR](nextrec/models/ranking/lr.py) | Logistic Regression | - | 已支持 |
|
|
264
264
|
| [AFM](nextrec/models/ranking/afm.py) | Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks | IJCAI 2017 | 已支持 |
|
|
265
|
-
| [FFM](nextrec/models/ranking/ffm.py) | Field-aware Factorization Machines | RecSys
|
|
265
|
+
| [FFM](nextrec/models/ranking/ffm.py) | Field-aware Factorization Machines | RecSys 2016 | 已支持 |
|
|
266
266
|
| [DeepFM](nextrec/models/ranking/deepfm.py) | DeepFM: A Factorization-Machine based Neural Network for CTR Prediction | IJCAI 2017 | 已支持 |
|
|
267
267
|
| [Wide&Deep](nextrec/models/ranking/widedeep.py) | Wide & Deep Learning for Recommender Systems | DLRS 2016 | 已支持 |
|
|
268
268
|
| [xDeepFM](nextrec/models/ranking/xdeepfm.py) | xDeepFM: Combining Explicit and Implicit Feature Interactions | KDD 2018 | 已支持 |
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|

|
|
8
8
|

|
|
9
9
|

|
|
10
|
-

|
|
11
11
|
|
|
12
12
|
中文文档 | [English Version](README_en.md)
|
|
13
13
|
|
|
@@ -40,7 +40,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
|
|
|
40
40
|
|
|
41
41
|
## NextRec近期进展
|
|
42
42
|
|
|
43
|
-
- **12/12/2025** 在v0.4.
|
|
43
|
+
- **12/12/2025** 在v0.4.12中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
|
|
44
44
|
- **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
|
|
45
45
|
- **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
|
|
46
46
|
- **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
|
|
@@ -181,11 +181,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
|
|
|
181
181
|
nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
182
182
|
```
|
|
183
183
|
|
|
184
|
-
> 截止当前版本0.4.11,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
184
|
+
> 截止当前版本0.4.12,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
185
185
|
|
|
186
186
|
## 兼容平台
|
|
187
187
|
|
|
188
|
-
当前最新版本为0.4.11,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
188
|
+
当前最新版本为0.4.12,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
189
189
|
|
|
190
190
|
| 平台 | 配置 |
|
|
191
191
|
|------|------|
|
|
@@ -203,7 +203,7 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
203
203
|
| [FM](nextrec/models/ranking/fm.py) | Factorization Machines | ICDM 2010 | 已支持 |
|
|
204
204
|
| [LR](nextrec/models/ranking/lr.py) | Logistic Regression | - | 已支持 |
|
|
205
205
|
| [AFM](nextrec/models/ranking/afm.py) | Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks | IJCAI 2017 | 已支持 |
|
|
206
|
-
| [FFM](nextrec/models/ranking/ffm.py) | Field-aware Factorization Machines | RecSys
|
|
206
|
+
| [FFM](nextrec/models/ranking/ffm.py) | Field-aware Factorization Machines | RecSys 2016 | 已支持 |
|
|
207
207
|
| [DeepFM](nextrec/models/ranking/deepfm.py) | DeepFM: A Factorization-Machine based Neural Network for CTR Prediction | IJCAI 2017 | 已支持 |
|
|
208
208
|
| [Wide&Deep](nextrec/models/ranking/widedeep.py) | Wide & Deep Learning for Recommender Systems | DLRS 2016 | 已支持 |
|
|
209
209
|
| [xDeepFM](nextrec/models/ranking/xdeepfm.py) | xDeepFM: Combining Explicit and Implicit Feature Interactions | KDD 2018 | 已支持 |
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|

|
|
8
8
|

|
|
9
9
|

|
|
10
|
-

|
|
11
11
|
|
|
12
12
|
English | [中文文档](README.md)
|
|
13
13
|
|
|
@@ -185,11 +185,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
|
|
|
185
185
|
nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
186
186
|
```
|
|
187
187
|
|
|
188
|
-
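> As of version 0.4.11, NextRec CLI supports single-machine training; distributed training features are currently under development.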
|
|
188
|
+
> As of version 0.4.12, NextRec CLI supports single-machine training; distributed training features are currently under development.
|
|
189
189
|
|
|
190
190
|
## Platform Compatibility
|
|
191
191
|
|
|
192
|
-
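The current version is 0.4.11. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version: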
|
|
192
|
+
The current version is 0.4.12. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
|
|
193
193
|
|
|
194
194
|
| Platform | Configuration |
|
|
195
195
|
|----------|---------------|
|
|
@@ -209,7 +209,7 @@ The current version is 0.4.11. All models and test code have been validated on t
|
|
|
209
209
|
| [FM](nextrec/models/ranking/fm.py) | Factorization Machines | ICDM 2010 | Supported |
|
|
210
210
|
| [LR](nextrec/models/ranking/lr.py) | Logistic Regression | - | Supported |
|
|
211
211
|
| [AFM](nextrec/models/ranking/afm.py) | Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks | IJCAI 2017 | Supported |
|
|
212
|
-
| [FFM](nextrec/models/ranking/ffm.py) | Field-aware Factorization Machines | RecSys
|
|
212
|
+
| [FFM](nextrec/models/ranking/ffm.py) | Field-aware Factorization Machines | RecSys 2016 | Supported |
|
|
213
213
|
| [DeepFM](nextrec/models/ranking/deepfm.py) | DeepFM: A Factorization-Machine based Neural Network for CTR Prediction | IJCAI 2017 | Supported |
|
|
214
214
|
| [Wide&Deep](nextrec/models/ranking/widedeep.py) | Wide & Deep Learning for Recommender Systems | DLRS 2016 | Supported |
|
|
215
215
|
| [xDeepFM](nextrec/models/ranking/xdeepfm.py) | xDeepFM: Combining Explicit and Implicit Feature Interactions | KDD 2018 | Supported |
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.4.12"
|
|
@@ -22,10 +22,10 @@ class Callback:
|
|
|
22
22
|
"""
|
|
23
23
|
Base callback.
|
|
24
24
|
|
|
25
|
-
Notes
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
25
|
+
Notes for DDP training:
|
|
26
|
+
In distributed training, the training loop runs on every rank.
|
|
27
|
+
For callbacks with side effects (saving, logging, etc.), set
|
|
28
|
+
``run_on_main_process_only=True`` to avoid multi-rank duplication.
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
31
|
run_on_main_process_only: bool = False
|
|
@@ -70,7 +70,7 @@ class Callback:
|
|
|
70
70
|
|
|
71
71
|
|
|
72
72
|
class CallbackList:
|
|
73
|
-
"""
|
|
73
|
+
"""Generates a list of callbacks"""
|
|
74
74
|
|
|
75
75
|
def __init__(self, callbacks: Optional[list[Callback]] = None):
|
|
76
76
|
self.callbacks = callbacks or []
|
|
@@ -78,61 +78,41 @@ class CallbackList:
|
|
|
78
78
|
def append(self, callback: Callback):
|
|
79
79
|
self.callbacks.append(callback)
|
|
80
80
|
|
|
81
|
-
def
|
|
81
|
+
def call(self, fn_name: str, *args, **kwargs):
|
|
82
82
|
for callback in self.callbacks:
|
|
83
|
-
callback.
|
|
83
|
+
if not callback.should_run():
|
|
84
|
+
continue
|
|
85
|
+
getattr(callback, fn_name)(*args, **kwargs)
|
|
86
|
+
|
|
87
|
+
def set_model(self, model):
|
|
88
|
+
self.call("set_model", model)
|
|
84
89
|
|
|
85
90
|
def set_params(self, params: dict):
|
|
86
|
-
|
|
87
|
-
callback.set_params(params)
|
|
91
|
+
self.call("set_params", params)
|
|
88
92
|
|
|
89
93
|
def on_train_begin(self, logs: Optional[dict] = None):
|
|
90
|
-
|
|
91
|
-
if not callback.should_run():
|
|
92
|
-
continue
|
|
93
|
-
callback.on_train_begin(logs)
|
|
94
|
+
self.call("on_train_begin", logs)
|
|
94
95
|
|
|
95
96
|
def on_train_end(self, logs: Optional[dict] = None):
|
|
96
|
-
|
|
97
|
-
if not callback.should_run():
|
|
98
|
-
continue
|
|
99
|
-
callback.on_train_end(logs)
|
|
97
|
+
self.call("on_train_end", logs)
|
|
100
98
|
|
|
101
99
|
def on_epoch_begin(self, epoch: int, logs: Optional[dict] = None):
|
|
102
|
-
|
|
103
|
-
if not callback.should_run():
|
|
104
|
-
continue
|
|
105
|
-
callback.on_epoch_begin(epoch, logs)
|
|
100
|
+
self.call("on_epoch_begin", epoch, logs)
|
|
106
101
|
|
|
107
102
|
def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
|
|
108
|
-
|
|
109
|
-
if not callback.should_run():
|
|
110
|
-
continue
|
|
111
|
-
callback.on_epoch_end(epoch, logs)
|
|
103
|
+
self.call("on_epoch_end", epoch, logs)
|
|
112
104
|
|
|
113
105
|
def on_batch_begin(self, batch: int, logs: Optional[dict] = None):
|
|
114
|
-
|
|
115
|
-
if not callback.should_run():
|
|
116
|
-
continue
|
|
117
|
-
callback.on_batch_begin(batch, logs)
|
|
106
|
+
self.call("on_batch_begin", batch, logs)
|
|
118
107
|
|
|
119
108
|
def on_batch_end(self, batch: int, logs: Optional[dict] = None):
|
|
120
|
-
|
|
121
|
-
if not callback.should_run():
|
|
122
|
-
continue
|
|
123
|
-
callback.on_batch_end(batch, logs)
|
|
109
|
+
self.call("on_batch_end", batch, logs)
|
|
124
110
|
|
|
125
111
|
def on_validation_begin(self, logs: Optional[dict] = None):
|
|
126
|
-
|
|
127
|
-
if not callback.should_run():
|
|
128
|
-
continue
|
|
129
|
-
callback.on_validation_begin(logs)
|
|
112
|
+
self.call("on_validation_begin", logs)
|
|
130
113
|
|
|
131
114
|
def on_validation_end(self, logs: Optional[dict] = None):
|
|
132
|
-
|
|
133
|
-
if not callback.should_run():
|
|
134
|
-
continue
|
|
135
|
-
callback.on_validation_end(logs)
|
|
115
|
+
self.call("on_validation_end", logs)
|
|
136
116
|
|
|
137
117
|
|
|
138
118
|
class EarlyStopper(Callback):
|
|
@@ -146,6 +126,20 @@ class EarlyStopper(Callback):
|
|
|
146
126
|
restore_best_weights: bool = True,
|
|
147
127
|
verbose: int = 1,
|
|
148
128
|
):
|
|
129
|
+
"""
|
|
130
|
+
Callback to stop training early if no improvement.
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
monitor: Metric name to monitor.
|
|
134
|
+
patience: Number of epochs with no improvement after which training will be stopped.
|
|
135
|
+
mode: One of {'min', 'max'}. In 'min' mode, training will stop when the
|
|
136
|
+
monitored metric has stopped decreasing; in 'max' mode it will stop
|
|
137
|
+
when the monitored metric has stopped increasing.
|
|
138
|
+
min_delta: Minimum change in the monitored metric to qualify as an improvement.
|
|
139
|
+
restore_best_weights: Whether to restore model weights from the epoch with the best value
|
|
140
|
+
of the monitored metric.
|
|
141
|
+
verbose: Verbosity mode. 1: messages will be printed. 0: silent.
|
|
142
|
+
"""
|
|
149
143
|
super().__init__()
|
|
150
144
|
self.monitor = monitor
|
|
151
145
|
self.patience = patience
|
|
@@ -233,6 +227,7 @@ class CheckpointSaver(Callback):
|
|
|
233
227
|
save_best_only: If True, only save when the model is considered the "best".
|
|
234
228
|
save_freq: Frequency of checkpoint saving ('epoch' or integer for every N epochs).
|
|
235
229
|
verbose: Verbosity mode.
|
|
230
|
+
run_on_main_process_only: Whether to run this callback only on the main process in DDP.
|
|
236
231
|
"""
|
|
237
232
|
|
|
238
233
|
def __init__(
|
|
@@ -274,7 +269,6 @@ class CheckpointSaver(Callback):
|
|
|
274
269
|
self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
|
275
270
|
|
|
276
271
|
def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
|
|
277
|
-
logging.info("")
|
|
278
272
|
logs = logs or {}
|
|
279
273
|
|
|
280
274
|
should_save = False
|
|
@@ -283,9 +277,6 @@ class CheckpointSaver(Callback):
|
|
|
283
277
|
elif isinstance(self.save_freq, int) and (epoch + 1) % self.save_freq == 0:
|
|
284
278
|
should_save = True
|
|
285
279
|
|
|
286
|
-
if not should_save and self.save_best_only:
|
|
287
|
-
should_save = False
|
|
288
|
-
|
|
289
280
|
# Check if this is the best model
|
|
290
281
|
current = logs.get(self.monitor)
|
|
291
282
|
is_best = False
|
|
@@ -297,11 +288,7 @@ class CheckpointSaver(Callback):
|
|
|
297
288
|
|
|
298
289
|
if should_save:
|
|
299
290
|
if not self.save_best_only or is_best:
|
|
300
|
-
checkpoint_path
|
|
301
|
-
self.checkpoint_path.parent
|
|
302
|
-
/ f"{self.checkpoint_path.stem}{self.checkpoint_path.suffix}"
|
|
303
|
-
)
|
|
304
|
-
self.save_checkpoint(checkpoint_path, epoch, logs)
|
|
291
|
+
self.save_checkpoint(self.checkpoint_path, epoch, logs)
|
|
305
292
|
|
|
306
293
|
if is_best:
|
|
307
294
|
# Use save_path directly without adding _best suffix since it may already contain it
|
|
@@ -371,7 +358,9 @@ class LearningRateScheduler(Callback):
|
|
|
371
358
|
# Step the scheduler
|
|
372
359
|
if hasattr(self.scheduler, "step"):
|
|
373
360
|
# Some schedulers need metrics
|
|
374
|
-
if
|
|
361
|
+
if logs is None:
|
|
362
|
+
logs = {}
|
|
363
|
+
if "val_loss" in logs and hasattr(self.scheduler, "mode"):
|
|
375
364
|
self.scheduler.step(logs["val_loss"])
|
|
376
365
|
else:
|
|
377
366
|
self.scheduler.step()
|
|
@@ -399,7 +388,6 @@ class MetricsLogger(Callback):
|
|
|
399
388
|
self.run_on_main_process_only = True
|
|
400
389
|
self.log_freq = log_freq
|
|
401
390
|
self.verbose = verbose
|
|
402
|
-
self.batch_count = 0
|
|
403
391
|
|
|
404
392
|
def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
|
|
405
393
|
if self.verbose > 0 and (
|
|
@@ -416,8 +404,10 @@ class MetricsLogger(Callback):
|
|
|
416
404
|
logging.info(f"Epoch {epoch + 1}: {metrics_str}")
|
|
417
405
|
|
|
418
406
|
def on_batch_end(self, batch: int, logs: Optional[dict] = None):
|
|
419
|
-
self.
|
|
420
|
-
|
|
407
|
+
if self.verbose > 1 and (
|
|
408
|
+
self.log_freq == "batch"
|
|
409
|
+
or (isinstance(self.log_freq, int) and (batch + 1) % self.log_freq == 0)
|
|
410
|
+
):
|
|
421
411
|
logs = logs or {}
|
|
422
412
|
metrics_str = " - ".join(
|
|
423
413
|
[
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Feature definitions
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 20/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -12,22 +12,20 @@ from nextrec.utils.embedding import get_auto_embedding_dim
|
|
|
12
12
|
from nextrec.utils.feature import normalize_to_list
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
class BaseFeature
|
|
15
|
+
class BaseFeature:
|
|
16
16
|
def __repr__(self):
|
|
17
17
|
params = {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
|
|
18
18
|
param_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
|
|
19
19
|
return f"{self.__class__.__name__}({param_str})"
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
class
|
|
22
|
+
class EmbeddingFeature(BaseFeature):
|
|
23
23
|
def __init__(
|
|
24
24
|
self,
|
|
25
25
|
name: str,
|
|
26
26
|
vocab_size: int,
|
|
27
|
-
max_len: int = 20,
|
|
28
27
|
embedding_name: str = "",
|
|
29
28
|
embedding_dim: int | None = 4,
|
|
30
|
-
combiner: str = "mean",
|
|
31
29
|
padding_idx: int | None = None,
|
|
32
30
|
init_type: str = "normal",
|
|
33
31
|
init_params: dict | None = None,
|
|
@@ -39,13 +37,15 @@ class SequenceFeature(BaseFeature):
|
|
|
39
37
|
):
|
|
40
38
|
self.name = name
|
|
41
39
|
self.vocab_size = vocab_size
|
|
42
|
-
self.max_len = max_len
|
|
43
40
|
self.embedding_name = embedding_name or name
|
|
44
|
-
self.embedding_dim =
|
|
41
|
+
self.embedding_dim = (
|
|
42
|
+
get_auto_embedding_dim(vocab_size)
|
|
43
|
+
if embedding_dim is None
|
|
44
|
+
else embedding_dim
|
|
45
|
+
)
|
|
45
46
|
|
|
46
47
|
self.init_type = init_type
|
|
47
48
|
self.init_params = init_params or {}
|
|
48
|
-
self.combiner = combiner
|
|
49
49
|
self.padding_idx = padding_idx
|
|
50
50
|
self.l1_reg = l1_reg
|
|
51
51
|
self.l2_reg = l2_reg
|
|
@@ -54,13 +54,15 @@ class SequenceFeature(BaseFeature):
|
|
|
54
54
|
self.freeze_pretrained = freeze_pretrained
|
|
55
55
|
|
|
56
56
|
|
|
57
|
-
class
|
|
57
|
+
class SequenceFeature(EmbeddingFeature):
|
|
58
58
|
def __init__(
|
|
59
59
|
self,
|
|
60
60
|
name: str,
|
|
61
61
|
vocab_size: int,
|
|
62
|
+
max_len: int = 20,
|
|
62
63
|
embedding_name: str = "",
|
|
63
64
|
embedding_dim: int | None = 4,
|
|
65
|
+
combiner: str = "mean",
|
|
64
66
|
padding_idx: int | None = None,
|
|
65
67
|
init_type: str = "normal",
|
|
66
68
|
init_params: dict | None = None,
|
|
@@ -70,19 +72,26 @@ class SparseFeature(BaseFeature):
|
|
|
70
72
|
pretrained_weight: torch.Tensor | None = None,
|
|
71
73
|
freeze_pretrained: bool = False,
|
|
72
74
|
):
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
75
|
+
super().__init__(
|
|
76
|
+
name=name,
|
|
77
|
+
vocab_size=vocab_size,
|
|
78
|
+
embedding_name=embedding_name,
|
|
79
|
+
embedding_dim=embedding_dim,
|
|
80
|
+
padding_idx=padding_idx,
|
|
81
|
+
init_type=init_type,
|
|
82
|
+
init_params=init_params,
|
|
83
|
+
l1_reg=l1_reg,
|
|
84
|
+
l2_reg=l2_reg,
|
|
85
|
+
trainable=trainable,
|
|
86
|
+
pretrained_weight=pretrained_weight,
|
|
87
|
+
freeze_pretrained=freeze_pretrained,
|
|
88
|
+
)
|
|
89
|
+
self.max_len = max_len
|
|
90
|
+
self.combiner = combiner
|
|
77
91
|
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
self.l1_reg = l1_reg
|
|
82
|
-
self.l2_reg = l2_reg
|
|
83
|
-
self.trainable = trainable
|
|
84
|
-
self.pretrained_weight = pretrained_weight
|
|
85
|
-
self.freeze_pretrained = freeze_pretrained
|
|
92
|
+
|
|
93
|
+
class SparseFeature(EmbeddingFeature):
|
|
94
|
+
pass
|
|
86
95
|
|
|
87
96
|
|
|
88
97
|
class DenseFeature(BaseFeature):
|
|
@@ -95,7 +104,11 @@ class DenseFeature(BaseFeature):
|
|
|
95
104
|
):
|
|
96
105
|
self.name = name
|
|
97
106
|
self.input_dim = max(int(input_dim or 1), 1)
|
|
98
|
-
self.embedding_dim = embedding_dim
|
|
107
|
+
self.embedding_dim = self.input_dim if embedding_dim is None else embedding_dim
|
|
108
|
+
if use_embedding and self.embedding_dim == 0:
|
|
109
|
+
raise ValueError(
|
|
110
|
+
"[Features Error] DenseFeature: use_embedding=True is incompatible with embedding_dim=0"
|
|
111
|
+
)
|
|
99
112
|
if embedding_dim is not None and embedding_dim > 1:
|
|
100
113
|
self.use_embedding = True
|
|
101
114
|
else:
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Layer implementations used across NextRec models.
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 20/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -28,6 +28,16 @@ class PredictionLayer(nn.Module):
|
|
|
28
28
|
use_bias: bool = True,
|
|
29
29
|
return_logits: bool = False,
|
|
30
30
|
):
|
|
31
|
+
"""
|
|
32
|
+
Prediction layer supporting binary and regression outputs.
|
|
33
|
+
|
|
34
|
+
Args:
|
|
35
|
+
task_type: A string or list of strings specifying the type of each task. supported types are "binary" and "regression".
|
|
36
|
+
task_dims: An integer or list of integers specifying the output dimension for each task.
|
|
37
|
+
If None, defaults to 1 for each task. If a single integer is provided, it is shared across all tasks.
|
|
38
|
+
use_bias: Whether to include a bias term in the prediction layer.
|
|
39
|
+
return_logits: If True, returns raw logits without applying activation functions.
|
|
40
|
+
"""
|
|
31
41
|
super().__init__()
|
|
32
42
|
self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
|
|
33
43
|
if len(self.task_types) == 0:
|
|
@@ -253,8 +263,11 @@ class EmbeddingLayer(nn.Module):
|
|
|
253
263
|
for feat in unique_feats.values():
|
|
254
264
|
if isinstance(feat, DenseFeature):
|
|
255
265
|
in_dim = max(int(getattr(feat, "input_dim", 1)), 1)
|
|
256
|
-
|
|
257
|
-
|
|
266
|
+
if getattr(feat, "use_embedding", False):
|
|
267
|
+
emb_dim = getattr(feat, "embedding_dim", None)
|
|
268
|
+
out_dim = max(int(emb_dim), 1) if emb_dim else in_dim
|
|
269
|
+
else:
|
|
270
|
+
out_dim = in_dim
|
|
258
271
|
dim += out_dim
|
|
259
272
|
elif isinstance(feat, SequenceFeature) and feat.combiner == "concat":
|
|
260
273
|
dim += feat.embedding_dim * feat.max_len
|
|
@@ -518,13 +531,17 @@ class MultiHeadSelfAttention(nn.Module):
|
|
|
518
531
|
self.use_residual = use_residual
|
|
519
532
|
self.dropout_rate = dropout
|
|
520
533
|
|
|
521
|
-
self.W_Q = nn.Linear(
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
self.
|
|
534
|
+
self.W_Q = nn.Linear(
|
|
535
|
+
embedding_dim, embedding_dim, bias=False
|
|
536
|
+
) # Query projection
|
|
537
|
+
self.W_K = nn.Linear(embedding_dim, embedding_dim, bias=False) # Key projection
|
|
538
|
+
self.W_V = nn.Linear(
|
|
539
|
+
embedding_dim, embedding_dim, bias=False
|
|
540
|
+
) # Value projection
|
|
541
|
+
self.W_O = nn.Linear(
|
|
542
|
+
embedding_dim, embedding_dim, bias=False
|
|
543
|
+
) # Output projection
|
|
525
544
|
|
|
526
|
-
if self.use_residual:
|
|
527
|
-
self.W_Res = nn.Linear(embedding_dim, embedding_dim, bias=False)
|
|
528
545
|
if use_layer_norm:
|
|
529
546
|
self.layer_norm = nn.LayerNorm(embedding_dim)
|
|
530
547
|
else:
|
|
@@ -537,81 +554,60 @@ class MultiHeadSelfAttention(nn.Module):
|
|
|
537
554
|
def forward(
|
|
538
555
|
self, x: torch.Tensor, attention_mask: torch.Tensor | None = None
|
|
539
556
|
) -> torch.Tensor:
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
Returns:
|
|
545
|
-
output: [batch_size, seq_len, embedding_dim]
|
|
546
|
-
"""
|
|
547
|
-
batch_size, seq_len, _ = x.shape
|
|
548
|
-
Q = self.W_Q(x) # [batch_size, seq_len, embedding_dim]
|
|
557
|
+
# x: [Batch, Length, Dim]
|
|
558
|
+
B, L, D = x.shape
|
|
559
|
+
|
|
560
|
+
Q = self.W_Q(x)
|
|
549
561
|
K = self.W_K(x)
|
|
550
562
|
V = self.W_V(x)
|
|
551
563
|
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
564
|
+
Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(
|
|
565
|
+
1, 2
|
|
566
|
+
) # [Batch, Heads, Length, head_dim]
|
|
567
|
+
K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
|
|
568
|
+
V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
|
|
569
|
+
|
|
570
|
+
key_padding_mask = None
|
|
571
|
+
if attention_mask is not None:
|
|
572
|
+
if attention_mask.dim() == 2: # [B,L], 1=valid, 0=pad
|
|
573
|
+
key_padding_mask = ~attention_mask.bool()
|
|
574
|
+
attn_mask = key_padding_mask[:, None, None, :]
|
|
575
|
+
attn_mask = attn_mask.expand(B, 1, L, L)
|
|
576
|
+
elif attention_mask.dim() == 3: # [B,L,L], 1=allowed, 0=masked
|
|
577
|
+
attn_mask = (~attention_mask.bool()).view(B, 1, L, L)
|
|
578
|
+
else:
|
|
579
|
+
raise ValueError("attention_mask must be [B,L] or [B,L,L]")
|
|
580
|
+
else:
|
|
581
|
+
attn_mask = None
|
|
556
582
|
|
|
557
583
|
if self.use_flash_attention:
|
|
558
|
-
|
|
559
|
-
if attention_mask is not None:
|
|
560
|
-
# Convert mask to [batch_size, 1, seq_len, seq_len] format
|
|
561
|
-
if attention_mask.dim() == 2:
|
|
562
|
-
# [B, L] -> [B, 1, 1, L]
|
|
563
|
-
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
|
564
|
-
elif attention_mask.dim() == 3:
|
|
565
|
-
# [B, L, L] -> [B, 1, L, L]
|
|
566
|
-
attention_mask = attention_mask.unsqueeze(1)
|
|
567
|
-
attention_output = F.scaled_dot_product_attention(
|
|
584
|
+
attn = F.scaled_dot_product_attention(
|
|
568
585
|
Q,
|
|
569
586
|
K,
|
|
570
587
|
V,
|
|
571
|
-
attn_mask=
|
|
588
|
+
attn_mask=attn_mask,
|
|
572
589
|
dropout_p=self.dropout_rate if self.training else 0.0,
|
|
573
|
-
)
|
|
574
|
-
# Handle potential NaN values
|
|
575
|
-
attention_output = torch.nan_to_num(attention_output, nan=0.0)
|
|
590
|
+
) # [B,H,L,dh]
|
|
576
591
|
else:
|
|
577
|
-
# Fallback to standard attention
|
|
578
592
|
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
|
|
593
|
+
if attn_mask is not None:
|
|
594
|
+
scores = scores.masked_fill(attn_mask, float("-inf"))
|
|
595
|
+
attn_weights = torch.softmax(scores, dim=-1)
|
|
596
|
+
attn_weights = self.dropout(attn_weights)
|
|
597
|
+
attn = torch.matmul(attn_weights, V) # [B,H,L,dh]
|
|
579
598
|
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
if attention_mask.dim() == 2:
|
|
583
|
-
# [B, L] -> [B, 1, 1, L]
|
|
584
|
-
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
|
585
|
-
elif attention_mask.dim() == 3:
|
|
586
|
-
# [B, L, L] -> [B, 1, L, L]
|
|
587
|
-
attention_mask = attention_mask.unsqueeze(1)
|
|
588
|
-
scores = scores.masked_fill(~attention_mask, float("-1e9"))
|
|
589
|
-
|
|
590
|
-
attention_weights = F.softmax(scores, dim=-1)
|
|
591
|
-
attention_weights = self.dropout(attention_weights)
|
|
592
|
-
attention_output = torch.matmul(
|
|
593
|
-
attention_weights, V
|
|
594
|
-
) # [batch_size, num_heads, seq_len, head_dim]
|
|
595
|
-
|
|
596
|
-
# Concatenate heads
|
|
597
|
-
attention_output = attention_output.transpose(1, 2).contiguous()
|
|
598
|
-
attention_output = attention_output.view(
|
|
599
|
-
batch_size, seq_len, self.embedding_dim
|
|
600
|
-
)
|
|
599
|
+
attn = attn.transpose(1, 2).contiguous().view(B, L, D)
|
|
600
|
+
out = self.W_O(attn)
|
|
601
601
|
|
|
602
|
-
# Output projection
|
|
603
|
-
output = self.W_O(attention_output)
|
|
604
|
-
|
|
605
|
-
# Residual connection
|
|
606
602
|
if self.use_residual:
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
# Layer normalization
|
|
603
|
+
out = out + x
|
|
610
604
|
if self.layer_norm is not None:
|
|
611
|
-
|
|
605
|
+
out = self.layer_norm(out)
|
|
612
606
|
|
|
613
|
-
|
|
614
|
-
|
|
607
|
+
if key_padding_mask is not None:
|
|
608
|
+
out = out * (~key_padding_mask).unsqueeze(-1)
|
|
609
|
+
|
|
610
|
+
return out
|
|
615
611
|
|
|
616
612
|
|
|
617
613
|
class AttentionPoolingLayer(nn.Module):
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
NextRec Basic Loggers
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 20/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -185,7 +185,7 @@ class TrainingLogger:
|
|
|
185
185
|
) -> dict[str, float]:
|
|
186
186
|
formatted: dict[str, float] = {}
|
|
187
187
|
for key, value in metrics.items():
|
|
188
|
-
if isinstance(value, numbers.
|
|
188
|
+
if isinstance(value, numbers.Real):
|
|
189
189
|
formatted[f"{split}/{key}"] = float(value)
|
|
190
190
|
elif hasattr(value, "item"):
|
|
191
191
|
try:
|