nextrec 0.4.8__tar.gz → 0.4.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.4.8 → nextrec-0.4.10}/PKG-INFO +6 -7
- {nextrec-0.4.8 → nextrec-0.4.10}/README.md +4 -5
- {nextrec-0.4.8 → nextrec-0.4.10}/README_en.md +4 -5
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/en/Getting started guide.md +1 -1
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/conf.py +1 -1
- nextrec-0.4.10/docs/rtd/nextrec.utils.rst +69 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md +1 -1
- nextrec-0.4.10/nextrec/__version__.py +1 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/callback.py +30 -15
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/features.py +1 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/layers.py +6 -8
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/loggers.py +14 -7
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/metrics.py +6 -76
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/model.py +316 -321
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/cli.py +185 -43
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/data/__init__.py +13 -16
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/data/batch_utils.py +3 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/data/data_processing.py +10 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/data/data_utils.py +9 -14
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/data/dataloader.py +31 -33
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/data/preprocessor.py +328 -255
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/loss/__init__.py +1 -5
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/loss/loss_utils.py +2 -8
- nextrec-0.4.10/nextrec/models/generative/__init__.py +9 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/generative/hstu.py +6 -4
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/multi_task/esmm.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/multi_task/mmoe.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/multi_task/ple.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/multi_task/poso.py +2 -3
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/multi_task/share_bottom.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/afm.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/autoint.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/dcn.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/dcn_v2.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/deepfm.py +6 -7
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/dien.py +3 -3
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/din.py +3 -3
- nextrec-0.4.10/nextrec/models/ranking/eulernet.py +365 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/fibinet.py +5 -5
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/fm.py +3 -7
- nextrec-0.4.10/nextrec/models/ranking/lr.py +120 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/masknet.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/pnn.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/widedeep.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/ranking/xdeepfm.py +2 -2
- nextrec-0.4.10/nextrec/models/representation/__init__.py +9 -0
- {nextrec-0.4.8/nextrec/models/generative → nextrec-0.4.10/nextrec/models/representation}/rqvae.py +9 -9
- nextrec-0.4.10/nextrec/models/retrieval/__init__.py +0 -0
- {nextrec-0.4.8/nextrec/models/match → nextrec-0.4.10/nextrec/models/retrieval}/dssm.py +8 -3
- {nextrec-0.4.8/nextrec/models/match → nextrec-0.4.10/nextrec/models/retrieval}/dssm_v2.py +8 -3
- {nextrec-0.4.8/nextrec/models/match → nextrec-0.4.10/nextrec/models/retrieval}/mind.py +4 -3
- {nextrec-0.4.8/nextrec/models/match → nextrec-0.4.10/nextrec/models/retrieval}/sdm.py +4 -3
- {nextrec-0.4.8/nextrec/models/match → nextrec-0.4.10/nextrec/models/retrieval}/youtube_dnn.py +8 -3
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/utils/__init__.py +60 -46
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/utils/config.py +8 -7
- nextrec-0.4.10/nextrec/utils/console.py +371 -0
- nextrec-0.4.8/nextrec/utils/synthetic_data.py → nextrec-0.4.10/nextrec/utils/data.py +102 -15
- nextrec-0.4.10/nextrec/utils/feature.py +29 -0
- nextrec-0.4.10/nextrec/utils/torch_utils.py +411 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/NextRec-CLI.md +97 -64
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/NextRec-CLI_zh.md +92 -59
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/feature_config.yaml +2 -2
- nextrec-0.4.10/nextrec_cli_preset/predict_config.yaml +32 -0
- nextrec-0.4.10/nextrec_cli_preset/predict_config_template.yaml +64 -0
- nextrec-0.4.10/nextrec_cli_preset/train_config.yaml +37 -0
- nextrec-0.4.10/nextrec_cli_preset/train_config_template.yaml +149 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/pyproject.toml +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/requirements.txt +1 -1
- {nextrec-0.4.8 → nextrec-0.4.10}/scripts/format_code.py +15 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/test/conftest.py +4 -3
- nextrec-0.4.8/test/test_utils.py → nextrec-0.4.10/test/helpers.py +20 -65
- nextrec-0.4.10/test/test_generative_models.py +303 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/test/test_layers.py +5 -14
- {nextrec-0.4.8 → nextrec-0.4.10}/test/test_match_models.py +11 -11
- {nextrec-0.4.8 → nextrec-0.4.10}/test/test_multitask_models.py +10 -10
- {nextrec-0.4.8 → nextrec-0.4.10}/test/test_ranking_models.py +107 -17
- nextrec-0.4.10/test/test_utils_console.py +124 -0
- nextrec-0.4.10/test/test_utils_data.py +217 -0
- nextrec-0.4.10/test/test_utils_embedding.py +72 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/example_match_dssm.py +4 -4
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/example_multitask.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/example_ranking_din.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/movielen_match_dssm.py +3 -3
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/movielen_ranking_deepfm.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +34 -33
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/notebooks/en/Hands on nextrec.ipynb +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/notebooks/zh//344/275/277/347/224/250RQ-VAE/346/236/204/345/273/272/350/257/255/344/271/211ID.ipynb +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/notebooks/zh//345/277/253/351/200/237/345/205/245/351/227/250nextrec.ipynb +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/run_all_match_models.py +3 -3
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/run_all_ranking_models.py +4 -0
- nextrec-0.4.8/docs/rtd/nextrec.utils.rst +0 -37
- nextrec-0.4.8/nextrec/__version__.py +0 -1
- nextrec-0.4.8/nextrec/models/generative/__init__.py +0 -16
- nextrec-0.4.8/nextrec/utils/cli_utils.py +0 -58
- nextrec-0.4.8/nextrec/utils/device.py +0 -78
- nextrec-0.4.8/nextrec/utils/distributed.py +0 -141
- nextrec-0.4.8/nextrec/utils/feature.py +0 -14
- nextrec-0.4.8/nextrec/utils/file.py +0 -92
- nextrec-0.4.8/nextrec/utils/initializer.py +0 -79
- nextrec-0.4.8/nextrec/utils/optimizer.py +0 -75
- nextrec-0.4.8/nextrec/utils/tensor.py +0 -72
- nextrec-0.4.8/nextrec_cli_preset/predict_config.yaml +0 -24
- nextrec-0.4.8/nextrec_cli_preset/train_config.yaml +0 -45
- nextrec-0.4.8/test/test_generative_models.py +0 -890
- {nextrec-0.4.8 → nextrec-0.4.10}/.github/workflows/publish.yml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/.github/workflows/tests.yml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/.gitignore +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/.readthedocs.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/CONTRIBUTING.md +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/LICENSE +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/MANIFEST.in +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/Feature Configuration.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/Model Parameters.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/Training Configuration.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/Training logs.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/logo.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/mmoe_tutorial.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/nextrec_diagram.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/assets/test data.png +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/dataset/ctcvr_task.csv +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/dataset/ecommerce_task.csv +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/dataset/match_task.csv +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/dataset/multitask_task.csv +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/dataset/ranking_task.csv +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/Makefile +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/index.md +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/make.bat +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/modules.rst +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/nextrec.basic.rst +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/nextrec.data.rst +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/nextrec.loss.rst +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/nextrec.rst +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/docs/rtd/requirements.txt +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/__init__.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/activation.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/basic/session.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/loss/listwise.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/loss/pairwise.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/loss/pointwise.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.4.8/nextrec/models/match → nextrec-0.4.10/nextrec/models/multi_task}/__init__.py +0 -0
- {nextrec-0.4.8/nextrec/models/multi_task → nextrec-0.4.10/nextrec/models/ranking}/__init__.py +0 -0
- nextrec-0.4.8/nextrec/models/ranking/__init__.py → nextrec-0.4.10/nextrec/models/ranking/ffm.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec/utils/model.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/din.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/pytest.ini +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/test/__init__.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/test/run_tests.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/test/test_base_model_regularization.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/test/test_losses.py +2 -2
- {nextrec-0.4.8 → nextrec-0.4.10}/test/test_preprocessor.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/test_requirements.txt +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/distributed/example_distributed_training.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/notebooks/zh//345/246/202/344/275/225/344/275/277/347/224/250DataProcessor/350/277/233/350/241/214/351/242/204/345/244/204/347/220/206.ipynb" +0 -0
- {nextrec-0.4.8 → nextrec-0.4.10}/tutorials/run_all_multitask_models.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.4.
|
|
3
|
+
Version: 0.4.10
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -33,6 +33,7 @@ Requires-Dist: pyarrow<15.0.0,>=12.0.0; sys_platform == 'win32'
|
|
|
33
33
|
Requires-Dist: pyarrow>=12.0.0; sys_platform == 'darwin'
|
|
34
34
|
Requires-Dist: pyarrow>=16.0.0; sys_platform == 'linux' and python_version >= '3.12'
|
|
35
35
|
Requires-Dist: pyyaml>=6.0
|
|
36
|
+
Requires-Dist: rich>=13.7.0
|
|
36
37
|
Requires-Dist: scikit-learn<2.0,>=1.2; sys_platform == 'linux' and python_version < '3.12'
|
|
37
38
|
Requires-Dist: scikit-learn>=1.3.0; sys_platform == 'darwin'
|
|
38
39
|
Requires-Dist: scikit-learn>=1.3.0; sys_platform == 'linux' and python_version >= '3.12'
|
|
@@ -43,7 +44,6 @@ Requires-Dist: scipy>=1.10.0; sys_platform == 'win32'
|
|
|
43
44
|
Requires-Dist: scipy>=1.11.0; sys_platform == 'linux' and python_version >= '3.12'
|
|
44
45
|
Requires-Dist: torch>=2.0.0
|
|
45
46
|
Requires-Dist: torchvision>=0.15.0
|
|
46
|
-
Requires-Dist: tqdm>=4.65.0
|
|
47
47
|
Requires-Dist: transformers>=4.38.0
|
|
48
48
|
Provides-Extra: dev
|
|
49
49
|
Requires-Dist: jupyter>=1.0.0; extra == 'dev'
|
|
@@ -66,7 +66,7 @@ Description-Content-Type: text/markdown
|
|
|
66
66
|

|
|
67
67
|

|
|
68
68
|

|
|
69
|
-

|
|
70
70
|
|
|
71
71
|
中文文档 | [English Version](README_en.md)
|
|
72
72
|
|
|
@@ -99,11 +99,10 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
|
|
|
99
99
|
|
|
100
100
|
## NextRec近期进展
|
|
101
101
|
|
|
102
|
-
- **12/12/2025** 在v0.4.
|
|
102
|
+
- **12/12/2025** 在v0.4.10中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
|
|
103
103
|
- **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
|
|
104
104
|
- **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
|
|
105
105
|
- **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
|
|
106
|
-
- **23/11/2025** 在v0.2.2中对basemodel进行了逻辑上的大幅重构和流程统一,并且对listwise/pairwise/pointwise损失进行了统一
|
|
107
106
|
- **11/11/2025** NextRec v0.1.0发布,我们提供了10余种Ranking模型,4种多任务模型和4种召回模型,以及统一的训练/日志/指标管理系统
|
|
108
107
|
|
|
109
108
|
## 架构
|
|
@@ -241,11 +240,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
|
|
|
241
240
|
nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
242
241
|
```
|
|
243
242
|
|
|
244
|
-
> 截止当前版本0.4.
|
|
243
|
+
> 截止当前版本0.4.10,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
245
244
|
|
|
246
245
|
## 兼容平台
|
|
247
246
|
|
|
248
|
-
当前最新版本为0.4.
|
|
247
|
+
当前最新版本为0.4.10,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
249
248
|
|
|
250
249
|
| 平台 | 配置 |
|
|
251
250
|
|------|------|
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|

|
|
8
8
|

|
|
9
9
|

|
|
10
|
-

|
|
11
11
|
|
|
12
12
|
中文文档 | [English Version](README_en.md)
|
|
13
13
|
|
|
@@ -40,11 +40,10 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
|
|
|
40
40
|
|
|
41
41
|
## NextRec近期进展
|
|
42
42
|
|
|
43
|
-
- **12/12/2025** 在v0.4.
|
|
43
|
+
- **12/12/2025** 在v0.4.10中加入了[RQ-VAE](/nextrec/models/representation/rqvae.py)模块。配套的[数据集](/dataset/ecommerce_task.csv)和[代码](tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb)已经同步在仓库中
|
|
44
44
|
- **07/12/2025** 发布了NextRec CLI命令行工具,它允许用户根据配置文件进行一键训练和推理,我们提供了相关的[教程](/nextrec_cli_preset/NextRec-CLI_zh.md)和[教学代码](/nextrec_cli_preset)
|
|
45
45
|
- **03/12/2025** NextRec获得了100颗🌟!感谢大家的支持
|
|
46
46
|
- **06/12/2025** 在v0.4.1中支持了单机多卡的分布式DDP训练,并且提供了配套的[代码](tutorials/distributed)
|
|
47
|
-
- **23/11/2025** 在v0.2.2中对basemodel进行了逻辑上的大幅重构和流程统一,并且对listwise/pairwise/pointwise损失进行了统一
|
|
48
47
|
- **11/11/2025** NextRec v0.1.0发布,我们提供了10余种Ranking模型,4种多任务模型和4种召回模型,以及统一的训练/日志/指标管理系统
|
|
49
48
|
|
|
50
49
|
## 架构
|
|
@@ -182,11 +181,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
|
|
|
182
181
|
nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
183
182
|
```
|
|
184
183
|
|
|
185
|
-
> 截止当前版本0.4.
|
|
184
|
+
> 截止当前版本0.4.10,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
186
185
|
|
|
187
186
|
## 兼容平台
|
|
188
187
|
|
|
189
|
-
当前最新版本为0.4.
|
|
188
|
+
当前最新版本为0.4.10,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
190
189
|
|
|
191
190
|
| 平台 | 配置 |
|
|
192
191
|
|------|------|
|
|
@@ -7,7 +7,7 @@
|
|
|
7
7
|

|
|
8
8
|

|
|
9
9
|

|
|
10
|
-

|
|
11
11
|
|
|
12
12
|
English | [中文文档](README.md)
|
|
13
13
|
|
|
@@ -42,11 +42,10 @@ NextRec is a modern recommendation framework built on PyTorch, delivering a unif
|
|
|
42
42
|
|
|
43
43
|
## NextRec Progress
|
|
44
44
|
|
|
45
|
-
- **12/12/2025** Added [RQ-VAE](/nextrec/models/
|
|
45
|
+
- **12/12/2025** Added [RQ-VAE](/nextrec/models/representation/rqvae.py), a common module for generative retrieval. Paired [dataset](/dataset/ecommerce_task.csv) and [notebook code](tutorials/notebooks/en/Build%20semantic%20ID%20with%20RQ-VAE.ipynb) are available.
|
|
46
46
|
- **07/12/2025** Released the NextRec CLI tool to run training/inference from configs. See the [guide](/nextrec_cli_preset/NextRec-CLI.md) and [reference code](/nextrec_cli_preset).
|
|
47
47
|
- **03/12/2025** NextRec reached 100 ⭐—thanks for the support!
|
|
48
48
|
- **06/12/2025** Added single-machine multi-GPU DDP training in v0.4.1 with supporting [code](tutorials/distributed).
|
|
49
|
-
- **23/11/2025** Major logical refactor of basemodel and unification of listwise/pairwise/pointwise losses in v0.2.2.
|
|
50
49
|
- **11/11/2025** NextRec v0.1.0 released with 10+ ranking models, 4 multi-task models, 4 retrieval models, and a unified training/logging/metrics system.
|
|
51
50
|
|
|
52
51
|
## Architecture
|
|
@@ -186,11 +185,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
|
|
|
186
185
|
nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
187
186
|
```
|
|
188
187
|
|
|
189
|
-
> As of version 0.4.
|
|
188
|
+
> As of version 0.4.10, NextRec CLI supports single-machine training; distributed training features are currently under development.
|
|
190
189
|
|
|
191
190
|
## Platform Compatibility
|
|
192
191
|
|
|
193
|
-
The current version is 0.4.
|
|
192
|
+
The current version is 0.4.10. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
|
|
194
193
|
|
|
195
194
|
| Platform | Configuration |
|
|
196
195
|
|----------|---------------|
|
|
@@ -102,4 +102,4 @@ metrics = model.evaluate(
|
|
|
102
102
|
- Multi-task: `tutorials/example_multitask.py`
|
|
103
103
|
- Notebooks: `tutorials/notebooks/zh/Hands on nextrec.ipynb`, `tutorials/notebooks/zh/Hands on dataprocessor.ipynb`
|
|
104
104
|
|
|
105
|
-
For large offline features or streaming loads, use `DataProcessor` and `RecDataLoader` to configure CSV/Parquet paths and streaming (`
|
|
105
|
+
For large offline features or streaming loads, use `DataProcessor` and `RecDataLoader` to configure CSV/Parquet paths and streaming (`streaming=True`) without changing model code.
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
nextrec.utils package
|
|
2
|
+
=====================
|
|
3
|
+
|
|
4
|
+
Submodules
|
|
5
|
+
----------
|
|
6
|
+
|
|
7
|
+
nextrec.utils.embedding module
|
|
8
|
+
------------------------------
|
|
9
|
+
|
|
10
|
+
.. automodule:: nextrec.utils.embedding
|
|
11
|
+
:members:
|
|
12
|
+
:undoc-members:
|
|
13
|
+
:show-inheritance:
|
|
14
|
+
|
|
15
|
+
nextrec.utils.console module
|
|
16
|
+
----------------------------
|
|
17
|
+
|
|
18
|
+
.. automodule:: nextrec.utils.console
|
|
19
|
+
:members:
|
|
20
|
+
:undoc-members:
|
|
21
|
+
:show-inheritance:
|
|
22
|
+
|
|
23
|
+
nextrec.utils.data module
|
|
24
|
+
-------------------------
|
|
25
|
+
|
|
26
|
+
.. automodule:: nextrec.utils.data
|
|
27
|
+
:members:
|
|
28
|
+
:undoc-members:
|
|
29
|
+
:show-inheritance:
|
|
30
|
+
|
|
31
|
+
nextrec.utils.feature module
|
|
32
|
+
----------------------------
|
|
33
|
+
|
|
34
|
+
.. automodule:: nextrec.utils.feature
|
|
35
|
+
:members:
|
|
36
|
+
:undoc-members:
|
|
37
|
+
:show-inheritance:
|
|
38
|
+
|
|
39
|
+
nextrec.utils.model module
|
|
40
|
+
--------------------------
|
|
41
|
+
|
|
42
|
+
.. automodule:: nextrec.utils.model
|
|
43
|
+
:members:
|
|
44
|
+
:undoc-members:
|
|
45
|
+
:show-inheritance:
|
|
46
|
+
|
|
47
|
+
nextrec.utils.config module
|
|
48
|
+
---------------------------
|
|
49
|
+
|
|
50
|
+
.. automodule:: nextrec.utils.config
|
|
51
|
+
:members:
|
|
52
|
+
:undoc-members:
|
|
53
|
+
:show-inheritance:
|
|
54
|
+
|
|
55
|
+
nextrec.utils.torch_utils module
|
|
56
|
+
--------------------------------
|
|
57
|
+
|
|
58
|
+
.. automodule:: nextrec.utils.torch_utils
|
|
59
|
+
:members:
|
|
60
|
+
:undoc-members:
|
|
61
|
+
:show-inheritance:
|
|
62
|
+
|
|
63
|
+
Module contents
|
|
64
|
+
---------------
|
|
65
|
+
|
|
66
|
+
.. automodule:: nextrec.utils
|
|
67
|
+
:members:
|
|
68
|
+
:undoc-members:
|
|
69
|
+
:show-inheritance:
|
{nextrec-0.4.8 → nextrec-0.4.10}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md
RENAMED
|
@@ -102,4 +102,4 @@ metrics = model.evaluate(
|
|
|
102
102
|
- 多任务:`tutorials/example_multitask.py`
|
|
103
103
|
- Notebook:`tutorials/notebooks/zh/Hands on nextrec.ipynb`、`tutorials/notebooks/zh/Hands on dataprocessor.ipynb`
|
|
104
104
|
|
|
105
|
-
如果需要大规模离线特征或流式加载,可结合 `DataProcessor`、`RecDataLoader` 配置 CSV/Parquet 路径与流式参数(`
|
|
105
|
+
如果需要大规模离线特征或流式加载,可结合 `DataProcessor`、`RecDataLoader` 配置 CSV/Parquet 路径与流式参数(`streaming=True`),在不修改模型代码的情况下完成训练与推理。
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.4.10"
|
|
@@ -2,17 +2,20 @@
|
|
|
2
2
|
Callback System for Training Process
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 19/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
import copy
|
|
10
10
|
import logging
|
|
11
|
-
|
|
11
|
+
import pickle
|
|
12
12
|
from pathlib import Path
|
|
13
|
+
from typing import Optional
|
|
14
|
+
|
|
13
15
|
import torch
|
|
14
|
-
|
|
16
|
+
|
|
15
17
|
from nextrec import __version__
|
|
18
|
+
from nextrec.basic.loggers import colorize, format_kv
|
|
16
19
|
|
|
17
20
|
|
|
18
21
|
class Callback:
|
|
@@ -209,8 +212,13 @@ class EarlyStopper(Callback):
|
|
|
209
212
|
if self.restore_best_weights and self.best_weights is not None:
|
|
210
213
|
if self.verbose > 0:
|
|
211
214
|
logging.info(
|
|
212
|
-
|
|
213
|
-
|
|
215
|
+
colorize(
|
|
216
|
+
format_kv(
|
|
217
|
+
"Restoring model weights from epoch",
|
|
218
|
+
f"{self.best_epoch + 1} with best {self.monitor}: {self.best_value:.6f}",
|
|
219
|
+
),
|
|
220
|
+
color="bright_blue",
|
|
221
|
+
)
|
|
214
222
|
)
|
|
215
223
|
self.model.load_state_dict(self.best_weights)
|
|
216
224
|
|
|
@@ -229,7 +237,8 @@ class CheckpointSaver(Callback):
|
|
|
229
237
|
|
|
230
238
|
def __init__(
|
|
231
239
|
self,
|
|
232
|
-
|
|
240
|
+
best_path: str | Path,
|
|
241
|
+
checkpoint_path: str | Path,
|
|
233
242
|
monitor: str = "val_auc",
|
|
234
243
|
mode: str = "max",
|
|
235
244
|
save_best_only: bool = False,
|
|
@@ -239,7 +248,8 @@ class CheckpointSaver(Callback):
|
|
|
239
248
|
):
|
|
240
249
|
super().__init__()
|
|
241
250
|
self.run_on_main_process_only = run_on_main_process_only
|
|
242
|
-
self.
|
|
251
|
+
self.best_path = Path(best_path)
|
|
252
|
+
self.checkpoint_path = Path(checkpoint_path)
|
|
243
253
|
self.monitor = monitor
|
|
244
254
|
self.mode = mode
|
|
245
255
|
self.save_best_only = save_best_only
|
|
@@ -260,14 +270,13 @@ class CheckpointSaver(Callback):
|
|
|
260
270
|
self.best_value = float("inf")
|
|
261
271
|
else:
|
|
262
272
|
self.best_value = float("-inf")
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
self.save_path.parent.mkdir(parents=True, exist_ok=True)
|
|
273
|
+
self.best_path.parent.mkdir(parents=True, exist_ok=True)
|
|
274
|
+
self.checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
|
266
275
|
|
|
267
276
|
def on_epoch_end(self, epoch: int, logs: Optional[dict] = None):
|
|
277
|
+
logging.info("")
|
|
268
278
|
logs = logs or {}
|
|
269
279
|
|
|
270
|
-
# Check if we should save this epoch
|
|
271
280
|
should_save = False
|
|
272
281
|
if self.save_freq == "epoch":
|
|
273
282
|
should_save = True
|
|
@@ -289,17 +298,23 @@ class CheckpointSaver(Callback):
|
|
|
289
298
|
if should_save:
|
|
290
299
|
if not self.save_best_only or is_best:
|
|
291
300
|
checkpoint_path = (
|
|
292
|
-
self.
|
|
293
|
-
/ f"{self.
|
|
301
|
+
self.checkpoint_path.parent
|
|
302
|
+
/ f"{self.checkpoint_path.stem}{self.checkpoint_path.suffix}"
|
|
294
303
|
)
|
|
295
304
|
self.save_checkpoint(checkpoint_path, epoch, logs)
|
|
296
305
|
|
|
297
306
|
if is_best:
|
|
298
307
|
# Use save_path directly without adding _best suffix since it may already contain it
|
|
299
|
-
self.save_checkpoint(self.
|
|
308
|
+
self.save_checkpoint(self.best_path, epoch, logs)
|
|
300
309
|
if self.verbose > 0:
|
|
301
310
|
logging.info(
|
|
302
|
-
|
|
311
|
+
colorize(
|
|
312
|
+
format_kv(
|
|
313
|
+
"Saved best model to",
|
|
314
|
+
f"{self.best_path} with {self.monitor}: {current:.6f}",
|
|
315
|
+
),
|
|
316
|
+
color="bright_blue",
|
|
317
|
+
)
|
|
303
318
|
)
|
|
304
319
|
|
|
305
320
|
def save_checkpoint(self, path: Path, epoch: int, logs: dict):
|
|
@@ -2,22 +2,22 @@
|
|
|
2
2
|
Layer implementations used across NextRec models.
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 19/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
from __future__ import annotations
|
|
10
10
|
|
|
11
|
+
from collections import OrderedDict
|
|
12
|
+
from itertools import combinations
|
|
13
|
+
|
|
11
14
|
import torch
|
|
12
15
|
import torch.nn as nn
|
|
13
16
|
import torch.nn.functional as F
|
|
14
17
|
|
|
15
|
-
from itertools import combinations
|
|
16
|
-
from collections import OrderedDict
|
|
17
|
-
|
|
18
|
-
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
19
|
-
from nextrec.utils.initializer import get_initializer
|
|
20
18
|
from nextrec.basic.activation import activation_layer
|
|
19
|
+
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
20
|
+
from nextrec.utils.torch_utils import get_initializer
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class PredictionLayer(nn.Module):
|
|
@@ -81,8 +81,6 @@ class PredictionLayer(nn.Module):
|
|
|
81
81
|
outputs.append(torch.sigmoid(task_logits))
|
|
82
82
|
elif task == "regression":
|
|
83
83
|
outputs.append(task_logits)
|
|
84
|
-
elif task == "multiclass":
|
|
85
|
-
outputs.append(torch.softmax(task_logits, dim=-1))
|
|
86
84
|
else:
|
|
87
85
|
raise ValueError(
|
|
88
86
|
f"[PredictionLayer Error]: Unsupported task_type '{task_type}'."
|
|
@@ -2,20 +2,20 @@
|
|
|
2
2
|
NextRec Basic Loggers
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 19/12/2025
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
|
-
import os
|
|
10
|
-
import re
|
|
11
|
-
import sys
|
|
12
|
-
import json
|
|
13
9
|
import copy
|
|
10
|
+
import json
|
|
14
11
|
import logging
|
|
15
12
|
import numbers
|
|
13
|
+
import os
|
|
14
|
+
import re
|
|
15
|
+
import sys
|
|
16
|
+
from typing import Any, Mapping
|
|
16
17
|
|
|
17
|
-
from
|
|
18
|
-
from nextrec.basic.session import create_session, Session
|
|
18
|
+
from nextrec.basic.session import Session, create_session
|
|
19
19
|
|
|
20
20
|
ANSI_CODES = {
|
|
21
21
|
"black": "\033[30m",
|
|
@@ -91,6 +91,13 @@ def colorize(text: str, color: str | None = None, bold: bool = False) -> str:
|
|
|
91
91
|
return result
|
|
92
92
|
|
|
93
93
|
|
|
94
|
+
def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:
|
|
95
|
+
"""Format key-value lines with consistent alignment."""
|
|
96
|
+
label_text = label if label.endswith(":") else f"{label}:"
|
|
97
|
+
prefix = " " * indent
|
|
98
|
+
return f"{prefix}{label_text:<{width}} {value}"
|
|
99
|
+
|
|
100
|
+
|
|
94
101
|
def setup_logger(session_id: str | os.PathLike | None = None):
|
|
95
102
|
"""Set up a logger that logs to both console and a file with ANSI formatting.
|
|
96
103
|
Only console output has colors; file output is stripped of ANSI codes.
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Metrics computation and configuration for model evaluation.
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 19/12/2025
|
|
6
6
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -11,15 +11,15 @@ from typing import Any
|
|
|
11
11
|
|
|
12
12
|
import numpy as np
|
|
13
13
|
from sklearn.metrics import (
|
|
14
|
-
|
|
14
|
+
accuracy_score,
|
|
15
|
+
f1_score,
|
|
15
16
|
log_loss,
|
|
16
|
-
mean_squared_error,
|
|
17
17
|
mean_absolute_error,
|
|
18
|
-
|
|
18
|
+
mean_squared_error,
|
|
19
19
|
precision_score,
|
|
20
|
-
recall_score,
|
|
21
|
-
f1_score,
|
|
22
20
|
r2_score,
|
|
21
|
+
recall_score,
|
|
22
|
+
roc_auc_score,
|
|
23
23
|
)
|
|
24
24
|
|
|
25
25
|
CLASSIFICATION_METRICS = {
|
|
@@ -44,11 +44,6 @@ TASK_DEFAULT_METRICS = {
|
|
|
44
44
|
+ [f"recall@{k}" for k in (5, 10, 20)]
|
|
45
45
|
+ [f"ndcg@{k}" for k in (5, 10, 20)]
|
|
46
46
|
+ [f"mrr@{k}" for k in (5, 10, 20)],
|
|
47
|
-
# generative/multiclass next-item prediction defaults
|
|
48
|
-
"multiclass": ["accuracy"]
|
|
49
|
-
+ [f"hitrate@{k}" for k in (1, 5, 10)]
|
|
50
|
-
+ [f"recall@{k}" for k in (1, 5, 10)]
|
|
51
|
-
+ [f"mrr@{k}" for k in (1, 5, 10)],
|
|
52
47
|
}
|
|
53
48
|
|
|
54
49
|
|
|
@@ -163,51 +158,6 @@ def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarr
|
|
|
163
158
|
return groups
|
|
164
159
|
|
|
165
160
|
|
|
166
|
-
def normalize_multiclass_inputs(
|
|
167
|
-
y_true: np.ndarray, y_pred: np.ndarray
|
|
168
|
-
) -> tuple[np.ndarray, np.ndarray]:
|
|
169
|
-
"""
|
|
170
|
-
Normalize multiclass inputs to consistent shapes.
|
|
171
|
-
|
|
172
|
-
y_true: [N] of class ids
|
|
173
|
-
y_pred: [N, C] of logits/probabilities
|
|
174
|
-
"""
|
|
175
|
-
labels = np.asarray(y_true).reshape(-1)
|
|
176
|
-
scores = np.asarray(y_pred)
|
|
177
|
-
if scores.ndim == 1:
|
|
178
|
-
scores = scores.reshape(scores.shape[0], -1)
|
|
179
|
-
if scores.shape[0] != labels.shape[0]:
|
|
180
|
-
raise ValueError(
|
|
181
|
-
f"[Metric Warning] y_true length {labels.shape[0]} != y_pred batch {scores.shape[0]} for multiclass metrics."
|
|
182
|
-
)
|
|
183
|
-
return labels.astype(int), scores
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def multiclass_topk_hit_rate(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
|
|
187
|
-
labels, scores = normalize_multiclass_inputs(y_true, y_pred)
|
|
188
|
-
if scores.shape[1] == 0:
|
|
189
|
-
return 0.0
|
|
190
|
-
k = min(k, scores.shape[1])
|
|
191
|
-
topk_idx = np.argpartition(-scores, kth=k - 1, axis=1)[:, :k]
|
|
192
|
-
hits = (topk_idx == labels[:, None]).any(axis=1)
|
|
193
|
-
return float(hits.mean()) if hits.size > 0 else 0.0
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
def multiclass_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
|
|
197
|
-
labels, scores = normalize_multiclass_inputs(y_true, y_pred)
|
|
198
|
-
if scores.shape[1] == 0:
|
|
199
|
-
return 0.0
|
|
200
|
-
k = min(k, scores.shape[1])
|
|
201
|
-
# full sort for stable ranks
|
|
202
|
-
topk_idx = np.argsort(-scores, axis=1)[:, :k]
|
|
203
|
-
ranks = np.full(labels.shape, fill_value=k + 1, dtype=np.float32)
|
|
204
|
-
for idx in range(k):
|
|
205
|
-
match = topk_idx[:, idx] == labels
|
|
206
|
-
ranks[match] = idx + 1
|
|
207
|
-
reciprocals = np.where(ranks <= k, 1.0 / ranks, 0.0)
|
|
208
|
-
return float(reciprocals.mean()) if reciprocals.size > 0 else 0.0
|
|
209
|
-
|
|
210
|
-
|
|
211
161
|
def compute_precision_at_k(
|
|
212
162
|
y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int
|
|
213
163
|
) -> float:
|
|
@@ -514,26 +464,6 @@ def compute_single_metric(
|
|
|
514
464
|
"""Compute a single metric given true and predicted values."""
|
|
515
465
|
y_p_binary = (y_pred > 0.5).astype(int)
|
|
516
466
|
metric_lower = metric.lower()
|
|
517
|
-
is_multiclass = task_type == "multiclass" and y_pred.ndim >= 2
|
|
518
|
-
if is_multiclass:
|
|
519
|
-
# Dedicated path for multiclass logits (e.g., next-item prediction)
|
|
520
|
-
labels, scores = normalize_multiclass_inputs(y_true, y_pred)
|
|
521
|
-
if metric_lower in ("accuracy", "acc"):
|
|
522
|
-
preds = scores.argmax(axis=1)
|
|
523
|
-
return float((preds == labels).mean())
|
|
524
|
-
if metric_lower.startswith("hitrate@") or metric_lower.startswith("hr@"):
|
|
525
|
-
k_str = metric_lower.split("@")[1]
|
|
526
|
-
k = int(k_str)
|
|
527
|
-
return multiclass_topk_hit_rate(labels, scores, k)
|
|
528
|
-
if metric_lower.startswith("recall@"):
|
|
529
|
-
k = int(metric_lower.split("@")[1])
|
|
530
|
-
return multiclass_topk_hit_rate(labels, scores, k)
|
|
531
|
-
if metric_lower.startswith("mrr@"):
|
|
532
|
-
k = int(metric_lower.split("@")[1])
|
|
533
|
-
return multiclass_mrr_at_k(labels, scores, k)
|
|
534
|
-
# fall back to accuracy if unsupported metric is requested
|
|
535
|
-
preds = scores.argmax(axis=1)
|
|
536
|
-
return float((preds == labels).mean())
|
|
537
467
|
try:
|
|
538
468
|
if metric_lower.startswith("recall@"):
|
|
539
469
|
k = int(metric_lower.split("@")[1])
|