nextrec 0.4.22__tar.gz → 0.4.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.4.22 → nextrec-0.4.23}/PKG-INFO +7 -5
- {nextrec-0.4.22 → nextrec-0.4.23}/README.md +6 -4
- {nextrec-0.4.22 → nextrec-0.4.23}/README_en.md +4 -4
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/en/Getting started guide.md +1 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/conf.py +1 -1
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/index.md +1 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md +1 -0
- nextrec-0.4.23/nextrec/__version__.py +1 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/metrics.py +1 -2
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/model.py +68 -73
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/summary.py +36 -2
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/preprocessor.py +137 -5
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/listwise.py +19 -6
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/pairwise.py +6 -4
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/pointwise.py +8 -6
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/esmm.py +3 -26
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/mmoe.py +2 -24
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/ple.py +13 -35
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/poso.py +4 -28
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/share_bottom.py +1 -24
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/afm.py +3 -27
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/autoint.py +5 -38
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/dcn.py +1 -26
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/dcn_v2.py +5 -33
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/deepfm.py +2 -29
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/dien.py +2 -28
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/din.py +2 -27
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/eulernet.py +3 -30
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/ffm.py +0 -26
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/fibinet.py +8 -32
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/fm.py +0 -29
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/lr.py +0 -30
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/masknet.py +4 -30
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/pnn.py +4 -28
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/widedeep.py +0 -32
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/xdeepfm.py +0 -30
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/dssm.py +0 -24
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/dssm_v2.py +0 -24
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/mind.py +0 -20
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/sdm.py +0 -20
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/youtube_dnn.py +0 -21
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/sequential/hstu.py +0 -18
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/model.py +79 -1
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/types.py +35 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/pyproject.toml +1 -1
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_ranking_models.py +0 -3
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/example_multitask.py +1 -8
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/example_ranking_din.py +3 -5
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on nextrec.ipynb +1 -1
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/zh//345/277/253/351/200/237/345/205/245/351/227/250nextrec.ipynb +1 -1
- nextrec-0.4.22/nextrec/__version__.py +0 -1
- {nextrec-0.4.22 → nextrec-0.4.23}/.github/workflows/publish.yml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/.github/workflows/tests.yml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/.gitignore +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/.readthedocs.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/CONTRIBUTING.md +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/LICENSE +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/MANIFEST.in +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/Feature Configuration.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/Model Parameters.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/Training Configuration.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/Training logs.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/logo.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/mmoe_tutorial.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/nextrec_diagram.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/assets/test data.png +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/dataset/ctcvr_task.csv +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/dataset/ecommerce_task.csv +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/dataset/match_task.csv +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/dataset/multitask_task.csv +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/dataset/ranking_task.csv +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/Makefile +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/make.bat +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/modules.rst +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.basic.rst +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.data.rst +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.loss.rst +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.rst +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/nextrec.utils.rst +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/docs/rtd/requirements.txt +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/activation.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/callback.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/features.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/heads.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/layers.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/loggers.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/basic/session.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/cli.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/batch_utils.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/data_processing.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/data_utils.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/data/dataloader.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/loss/grad_norm.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/generative/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/multi_task/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/ranking/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/autorec.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/bpr.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/cl4srec.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/lightgcn.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/mf.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/rqvae.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/representation/s3rec.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/retrieval/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/models/sequential/sasrec.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/config.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/console.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/data.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/feature.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/loss.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec/utils/torch_utils.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI.md +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/feature_config.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/din.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/predict_config.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/predict_config_template.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/train_config.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/nextrec_cli_preset/train_config_template.yaml +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/pytest.ini +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/requirements.txt +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/scripts/format_code.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/__init__.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/conftest.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/helpers.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/run_tests.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_base_model_regularization.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_generative_models.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_layers.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_losses.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_match_models.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_multitask_models.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_preprocessor.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_utils_console.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_utils_data.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test/test_utils_embedding.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/test_requirements.txt +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/example_match.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/movielen_match_dssm.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/movielen_ranking_deepfm.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/zh//344/275/277/347/224/250RQ-VAE/346/236/204/345/273/272/350/257/255/344/271/211ID.ipynb" +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/notebooks/zh//345/246/202/344/275/225/344/275/277/347/224/250DataProcessor/350/277/233/350/241/214/351/242/204/345/244/204/347/220/206.ipynb" +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/run_all_match_models.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/run_all_multitask_models.py +0 -0
- {nextrec-0.4.22 → nextrec-0.4.23}/tutorials/run_all_ranking_models.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.4.22
|
|
3
|
+
Version: 0.4.23
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -69,7 +69,7 @@ Description-Content-Type: text/markdown
|
|
|
69
69
|

|
|
70
70
|

|
|
71
71
|

|
|
72
|
-

|
|
73
73
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
74
74
|
|
|
75
75
|
中文文档 | [English Version](README_en.md)
|
|
@@ -191,6 +191,8 @@ model = DIN(
|
|
|
191
191
|
dense_features=dense_features,
|
|
192
192
|
sparse_features=sparse_features,
|
|
193
193
|
sequence_features=sequence_features,
|
|
194
|
+
behavior_feature_name="sequence_0",
|
|
195
|
+
candidate_feature_name="item_id",
|
|
194
196
|
mlp_params=mlp_params,
|
|
195
197
|
attention_hidden_units=[80, 40],
|
|
196
198
|
attention_activation='sigmoid',
|
|
@@ -204,7 +206,7 @@ model = DIN(
|
|
|
204
206
|
session_id="din_tutorial", # 实验id,用于存放训练日志
|
|
205
207
|
)
|
|
206
208
|
|
|
207
|
-
#
|
|
209
|
+
# 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
|
|
208
210
|
model.compile(
|
|
209
211
|
optimizer = "adam",
|
|
210
212
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
|
|
@@ -247,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
247
249
|
|
|
248
250
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
249
251
|
|
|
250
|
-
> 截止当前版本0.4.
|
|
252
|
+
> 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
251
253
|
|
|
252
254
|
## 兼容平台
|
|
253
255
|
|
|
254
|
-
当前最新版本为0.4.
|
|
256
|
+
当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
255
257
|
|
|
256
258
|
| 平台 | 配置 |
|
|
257
259
|
|------|------|
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|

|
|
9
9
|

|
|
10
10
|

|
|
11
|
-

|
|
12
12
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
13
13
|
|
|
14
14
|
中文文档 | [English Version](README_en.md)
|
|
@@ -130,6 +130,8 @@ model = DIN(
|
|
|
130
130
|
dense_features=dense_features,
|
|
131
131
|
sparse_features=sparse_features,
|
|
132
132
|
sequence_features=sequence_features,
|
|
133
|
+
behavior_feature_name="sequence_0",
|
|
134
|
+
candidate_feature_name="item_id",
|
|
133
135
|
mlp_params=mlp_params,
|
|
134
136
|
attention_hidden_units=[80, 40],
|
|
135
137
|
attention_activation='sigmoid',
|
|
@@ -143,7 +145,7 @@ model = DIN(
|
|
|
143
145
|
session_id="din_tutorial", # 实验id,用于存放训练日志
|
|
144
146
|
)
|
|
145
147
|
|
|
146
|
-
#
|
|
148
|
+
# 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
|
|
147
149
|
model.compile(
|
|
148
150
|
optimizer = "adam",
|
|
149
151
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
|
|
@@ -186,11 +188,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
186
188
|
|
|
187
189
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
188
190
|
|
|
189
|
-
> 截止当前版本0.4.
|
|
191
|
+
> 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
190
192
|
|
|
191
193
|
## 兼容平台
|
|
192
194
|
|
|
193
|
-
当前最新版本为0.4.
|
|
195
|
+
当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
194
196
|
|
|
195
197
|
| 平台 | 配置 |
|
|
196
198
|
|------|------|
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|

|
|
9
9
|

|
|
10
10
|

|
|
11
|
-

|
|
12
12
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
13
13
|
|
|
14
14
|
English | [中文文档](README.md)
|
|
@@ -148,7 +148,7 @@ model = DIN(
|
|
|
148
148
|
session_id="din_tutorial", # experiment id for logs
|
|
149
149
|
)
|
|
150
150
|
|
|
151
|
-
# Compile model
|
|
151
|
+
# Compile model; configure optimizer/loss/scheduler via compile()
|
|
152
152
|
model.compile(
|
|
153
153
|
optimizer = "adam",
|
|
154
154
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
|
|
@@ -191,11 +191,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
191
191
|
|
|
192
192
|
Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.
|
|
193
193
|
|
|
194
|
-
> As of version 0.4.
|
|
194
|
+
> As of version 0.4.23, NextRec CLI supports single-machine training; distributed training features are currently under development.
|
|
195
195
|
|
|
196
196
|
## Platform Compatibility
|
|
197
197
|
|
|
198
|
-
The current version is 0.4.
|
|
198
|
+
The current version is 0.4.23. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
|
|
199
199
|
|
|
200
200
|
| Platform | Configuration |
|
|
201
201
|
|----------|---------------|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.4.23"
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Metrics computation and configuration for model evaluation.
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 29/12/2025
|
|
6
6
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -39,7 +39,6 @@ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
|
|
|
39
39
|
TASK_DEFAULT_METRICS = {
|
|
40
40
|
"binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
|
|
41
41
|
"regression": ["mse", "mae", "rmse", "r2", "mape"],
|
|
42
|
-
"multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
|
|
43
42
|
"matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
|
|
44
43
|
+ [f"recall@{k}" for k in (5, 10, 20)]
|
|
45
44
|
+ [f"ndcg@{k}" for k in (5, 10, 20)]
|
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
Base Model & Base Match Model Class
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 29/12/2025
|
|
6
6
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
import getpass
|
|
10
10
|
import logging
|
|
11
11
|
import os
|
|
12
|
+
import sys
|
|
12
13
|
import pickle
|
|
13
14
|
import socket
|
|
14
15
|
from pathlib import Path
|
|
@@ -16,6 +17,16 @@ from typing import Any, Literal
|
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
import pandas as pd
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import swanlab # type: ignore
|
|
23
|
+
except ModuleNotFoundError:
|
|
24
|
+
swanlab = None
|
|
25
|
+
try:
|
|
26
|
+
import wandb # type: ignore
|
|
27
|
+
except ModuleNotFoundError:
|
|
28
|
+
wandb = None
|
|
29
|
+
|
|
19
30
|
import torch
|
|
20
31
|
import torch.distributed as dist
|
|
21
32
|
import torch.nn as nn
|
|
@@ -74,13 +85,19 @@ from nextrec.utils.torch_utils import (
|
|
|
74
85
|
to_tensor,
|
|
75
86
|
)
|
|
76
87
|
from nextrec.utils.config import safe_value
|
|
77
|
-
from nextrec.utils.model import
|
|
88
|
+
from nextrec.utils.model import (
|
|
89
|
+
compute_ranking_loss,
|
|
90
|
+
get_loss_list,
|
|
91
|
+
resolve_loss_weights,
|
|
92
|
+
get_training_modes,
|
|
93
|
+
)
|
|
78
94
|
from nextrec.utils.types import (
|
|
79
95
|
LossName,
|
|
80
96
|
OptimizerName,
|
|
81
97
|
SchedulerName,
|
|
82
98
|
TrainingModeName,
|
|
83
99
|
TaskTypeName,
|
|
100
|
+
MetricsName,
|
|
84
101
|
)
|
|
85
102
|
|
|
86
103
|
|
|
@@ -90,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
90
107
|
raise NotImplementedError
|
|
91
108
|
|
|
92
109
|
@property
|
|
93
|
-
def default_task(self) ->
|
|
110
|
+
def default_task(self) -> TaskTypeName | list[TaskTypeName]:
|
|
94
111
|
raise NotImplementedError
|
|
95
112
|
|
|
96
113
|
def __init__(
|
|
@@ -139,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
139
156
|
world_size: Number of processes (defaults to env WORLD_SIZE).
|
|
140
157
|
local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
|
|
141
158
|
ddp_find_unused_parameters: Defaults to False. Set it to True only when the DDP model has unused parameters; in most cases it should remain False.
|
|
159
|
+
|
|
160
|
+
Note:
|
|
161
|
+
Optimizer, scheduler, and loss are configured via compile().
|
|
142
162
|
"""
|
|
143
163
|
super(BaseModel, self).__init__()
|
|
144
164
|
|
|
@@ -171,24 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
171
191
|
dense_features, sparse_features, sequence_features, target, id_columns
|
|
172
192
|
)
|
|
173
193
|
|
|
174
|
-
self.task = self.default_task
|
|
194
|
+
self.task = task or self.default_task
|
|
175
195
|
self.nums_task = len(self.task) if isinstance(self.task, list) else 1
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
if
|
|
179
|
-
|
|
180
|
-
"[BaseModel-init Error] training_mode list length must match number of tasks."
|
|
181
|
-
)
|
|
182
|
-
else:
|
|
183
|
-
training_modes = [training_mode] * self.nums_task
|
|
184
|
-
if any(
|
|
185
|
-
mode not in {"pointwise", "pairwise", "listwise"} for mode in training_modes
|
|
186
|
-
):
|
|
187
|
-
raise ValueError(
|
|
188
|
-
"[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
|
|
189
|
-
)
|
|
190
|
-
self.training_modes = training_modes
|
|
191
|
-
self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
|
|
196
|
+
self.training_modes = get_training_modes(training_mode, self.nums_task)
|
|
197
|
+
self.training_mode = (
|
|
198
|
+
self.training_modes if self.nums_task > 1 else self.training_modes[0]
|
|
199
|
+
)
|
|
192
200
|
|
|
193
201
|
self.embedding_l1_reg = embedding_l1_reg
|
|
194
202
|
self.dense_l1_reg = dense_l1_reg
|
|
@@ -196,8 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
196
204
|
self.dense_l2_reg = dense_l2_reg
|
|
197
205
|
self.regularization_weights = []
|
|
198
206
|
self.embedding_params = []
|
|
199
|
-
|
|
207
|
+
|
|
200
208
|
self.ignore_label = None
|
|
209
|
+
self.compiled = False
|
|
201
210
|
|
|
202
211
|
self.max_gradient_norm = 1.0
|
|
203
212
|
self.logger_initialized = False
|
|
@@ -431,28 +440,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
431
440
|
"pairwise": "bpr",
|
|
432
441
|
"listwise": "listnet",
|
|
433
442
|
}
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
elif isinstance(effective_loss, list):
|
|
438
|
-
if not effective_loss:
|
|
439
|
-
loss_list = [default_losses[mode] for mode in self.training_modes]
|
|
440
|
-
else:
|
|
441
|
-
if len(effective_loss) != self.nums_task:
|
|
442
|
-
raise ValueError(
|
|
443
|
-
f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
|
|
444
|
-
)
|
|
445
|
-
loss_list = list(effective_loss)
|
|
446
|
-
else:
|
|
447
|
-
loss_list = [effective_loss] * self.nums_task
|
|
448
|
-
|
|
449
|
-
for idx, mode in enumerate(self.training_modes):
|
|
450
|
-
if isinstance(loss_list[idx], str) and loss_list[idx] in {
|
|
451
|
-
"bce",
|
|
452
|
-
"binary_crossentropy",
|
|
453
|
-
}:
|
|
454
|
-
if mode in {"pairwise", "listwise"}:
|
|
455
|
-
loss_list[idx] = default_losses[mode]
|
|
443
|
+
loss_list = get_loss_list(
|
|
444
|
+
loss, self.training_modes, self.nums_task, default_losses
|
|
445
|
+
)
|
|
456
446
|
self.loss_params = loss_params or {}
|
|
457
447
|
optimizer_params = optimizer_params or {}
|
|
458
448
|
self.optimizer_name = (
|
|
@@ -516,30 +506,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
516
506
|
nums_task=self.nums_task, device=self.device, **grad_norm_params
|
|
517
507
|
)
|
|
518
508
|
self.loss_weights = None
|
|
519
|
-
elif loss_weights is None:
|
|
520
|
-
self.loss_weights = None
|
|
521
|
-
elif self.nums_task == 1:
|
|
522
|
-
if isinstance(loss_weights, (list, tuple)):
|
|
523
|
-
if len(loss_weights) != 1:
|
|
524
|
-
raise ValueError(
|
|
525
|
-
"[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
|
|
526
|
-
)
|
|
527
|
-
loss_weights = loss_weights[0]
|
|
528
|
-
self.loss_weights = [float(loss_weights)] # type: ignore
|
|
529
509
|
else:
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
elif isinstance(loss_weights, (list, tuple)):
|
|
533
|
-
weights = [float(w) for w in loss_weights]
|
|
534
|
-
if len(weights) != self.nums_task:
|
|
535
|
-
raise ValueError(
|
|
536
|
-
f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
|
|
537
|
-
)
|
|
538
|
-
else:
|
|
539
|
-
raise TypeError(
|
|
540
|
-
f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
|
|
541
|
-
)
|
|
542
|
-
self.loss_weights = weights
|
|
510
|
+
self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
|
|
511
|
+
self.compiled = True
|
|
543
512
|
|
|
544
513
|
def compute_loss(self, y_pred, y_true):
|
|
545
514
|
if y_true is None:
|
|
@@ -602,9 +571,6 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
602
571
|
for i, (start, end) in enumerate(slices): # type: ignore
|
|
603
572
|
y_pred_i = y_pred[:, start:end]
|
|
604
573
|
y_true_i = y_true[:, start:end]
|
|
605
|
-
total_count = y_true_i.shape[0]
|
|
606
|
-
# valid_count = None
|
|
607
|
-
|
|
608
574
|
# mask ignored labels
|
|
609
575
|
if self.ignore_label is not None:
|
|
610
576
|
valid_mask = y_true_i != self.ignore_label
|
|
@@ -613,11 +579,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
613
579
|
if not torch.any(valid_mask):
|
|
614
580
|
task_losses.append(y_pred_i.sum() * 0.0)
|
|
615
581
|
continue
|
|
616
|
-
# valid_count = valid_mask.sum().to(dtype=y_true_i.dtype)
|
|
617
582
|
y_pred_i = y_pred_i[valid_mask]
|
|
618
583
|
y_true_i = y_true_i[valid_mask]
|
|
619
|
-
# else:
|
|
620
|
-
# valid_count = y_true_i.new_tensor(float(total_count))
|
|
621
584
|
|
|
622
585
|
mode = self.training_modes[i]
|
|
623
586
|
|
|
@@ -691,7 +654,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
691
654
|
train_data=None,
|
|
692
655
|
valid_data=None,
|
|
693
656
|
metrics: (
|
|
694
|
-
list[
|
|
657
|
+
list[MetricsName] | dict[str, list[MetricsName]] | None
|
|
695
658
|
) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
|
|
696
659
|
epochs: int = 1,
|
|
697
660
|
shuffle: bool = True,
|
|
@@ -705,6 +668,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
705
668
|
use_tensorboard: bool = True,
|
|
706
669
|
use_wandb: bool = False,
|
|
707
670
|
use_swanlab: bool = False,
|
|
671
|
+
wandb_api: str | None = None,
|
|
672
|
+
swanlab_api: str | None = None,
|
|
708
673
|
wandb_kwargs: dict | None = None,
|
|
709
674
|
swanlab_kwargs: dict | None = None,
|
|
710
675
|
auto_ddp_sampler: bool = True,
|
|
@@ -734,6 +699,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
734
699
|
use_tensorboard: Enable tensorboard logging.
|
|
735
700
|
use_wandb: Enable Weights & Biases logging.
|
|
736
701
|
use_swanlab: Enable SwanLab logging.
|
|
702
|
+
wandb_api: W&B API key for non-tty login.
|
|
703
|
+
swanlab_api: SwanLab API key for non-tty login.
|
|
737
704
|
wandb_kwargs: Optional kwargs for wandb.init(...).
|
|
738
705
|
swanlab_kwargs: Optional kwargs for swanlab.init(...).
|
|
739
706
|
auto_ddp_sampler: Attach DistributedSampler automatically when distributed; set to False when data is already sharded per rank.
|
|
@@ -751,6 +718,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
751
718
|
)
|
|
752
719
|
self.to(self.device)
|
|
753
720
|
|
|
721
|
+
if not self.compiled:
|
|
722
|
+
self.compile(
|
|
723
|
+
optimizer="adam",
|
|
724
|
+
optimizer_params={},
|
|
725
|
+
scheduler=None,
|
|
726
|
+
scheduler_params={},
|
|
727
|
+
loss=None,
|
|
728
|
+
loss_params={},
|
|
729
|
+
)
|
|
730
|
+
|
|
754
731
|
if (
|
|
755
732
|
self.distributed
|
|
756
733
|
and dist.is_available()
|
|
@@ -825,6 +802,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
825
802
|
}
|
|
826
803
|
training_config: dict = safe_value(training_config) # type: ignore
|
|
827
804
|
|
|
805
|
+
if self.is_main_process:
|
|
806
|
+
is_tty = sys.stdin.isatty() and sys.stdout.isatty()
|
|
807
|
+
if not is_tty:
|
|
808
|
+
if use_wandb and wandb_api:
|
|
809
|
+
if wandb is None:
|
|
810
|
+
logging.warning(
|
|
811
|
+
"[BaseModel-fit] wandb not installed, skip wandb login."
|
|
812
|
+
)
|
|
813
|
+
else:
|
|
814
|
+
wandb.login(key=wandb_api)
|
|
815
|
+
if use_swanlab and swanlab_api:
|
|
816
|
+
if swanlab is None:
|
|
817
|
+
logging.warning(
|
|
818
|
+
"[BaseModel-fit] swanlab not installed, skip swanlab login."
|
|
819
|
+
)
|
|
820
|
+
else:
|
|
821
|
+
swanlab.login(api_key=swanlab_api)
|
|
822
|
+
|
|
828
823
|
self.training_logger = (
|
|
829
824
|
TrainingLogger(
|
|
830
825
|
session=self.session,
|
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Summary utilities for BaseModel.
|
|
3
|
+
|
|
4
|
+
Date: create on 03/12/2025
|
|
5
|
+
Checkpoint: edit on 29/12/2025
|
|
6
|
+
Author: Yang Zhou,zyaztec@gmail.com
|
|
3
7
|
"""
|
|
4
8
|
|
|
5
9
|
from __future__ import annotations
|
|
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader
|
|
|
12
16
|
|
|
13
17
|
from nextrec.basic.loggers import colorize, format_kv
|
|
14
18
|
from nextrec.data.data_processing import extract_label_arrays, get_data_length
|
|
19
|
+
from nextrec.utils.types import TaskTypeName
|
|
15
20
|
|
|
16
21
|
|
|
17
22
|
class SummarySet:
|
|
23
|
+
model_name: str
|
|
24
|
+
dense_features: list[Any]
|
|
25
|
+
sparse_features: list[Any]
|
|
26
|
+
sequence_features: list[Any]
|
|
27
|
+
task: TaskTypeName | list[TaskTypeName]
|
|
28
|
+
target_columns: list[str]
|
|
29
|
+
nums_task: int
|
|
30
|
+
metrics: Any
|
|
31
|
+
device: Any
|
|
32
|
+
optimizer_name: str
|
|
33
|
+
optimizer_params: dict[str, Any]
|
|
34
|
+
scheduler_name: str | None
|
|
35
|
+
scheduler_params: dict[str, Any]
|
|
36
|
+
loss_config: Any
|
|
37
|
+
loss_weights: Any
|
|
38
|
+
grad_norm: Any
|
|
39
|
+
embedding_l1_reg: float
|
|
40
|
+
embedding_l2_reg: float
|
|
41
|
+
dense_l1_reg: float
|
|
42
|
+
dense_l2_reg: float
|
|
43
|
+
early_stop_patience: int
|
|
44
|
+
max_gradient_norm: float | None
|
|
45
|
+
metrics_sample_limit: int | None
|
|
46
|
+
session_id: str | None
|
|
47
|
+
features_config_path: str
|
|
48
|
+
checkpoint_path: str
|
|
49
|
+
train_data_summary: dict[str, Any] | None
|
|
50
|
+
valid_data_summary: dict[str, Any] | None
|
|
51
|
+
|
|
18
52
|
def build_data_summary(
|
|
19
53
|
self, data: Any, data_loader: DataLoader | None, sample_key: str
|
|
20
54
|
):
|
|
@@ -305,7 +339,7 @@ class SummarySet:
|
|
|
305
339
|
lines = details.get("lines", [])
|
|
306
340
|
logger.info(f"{target_name}:")
|
|
307
341
|
for label, value in lines:
|
|
308
|
-
logger.info(format_kv(label, value))
|
|
342
|
+
logger.info(f" {format_kv(label, value)}")
|
|
309
343
|
|
|
310
344
|
if self.valid_data_summary:
|
|
311
345
|
if self.train_data_summary:
|
|
@@ -320,4 +354,4 @@ class SummarySet:
|
|
|
320
354
|
lines = details.get("lines", [])
|
|
321
355
|
logger.info(f"{target_name}:")
|
|
322
356
|
for label, value in lines:
|
|
323
|
-
logger.info(format_kv(label, value))
|
|
357
|
+
logger.info(f" {format_kv(label, value)}")
|