nextrec 0.4.21__tar.gz → 0.4.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.4.21 → nextrec-0.4.23}/PKG-INFO +8 -6
- {nextrec-0.4.21 → nextrec-0.4.23}/README.md +7 -5
- {nextrec-0.4.21 → nextrec-0.4.23}/README_en.md +5 -5
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/en/Getting started guide.md +1 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/conf.py +1 -1
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/index.md +1 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md +1 -0
- nextrec-0.4.23/nextrec/__version__.py +1 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/activation.py +1 -1
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/heads.py +2 -3
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/metrics.py +1 -2
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/model.py +115 -80
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/summary.py +36 -2
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/preprocessor.py +137 -5
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/__init__.py +0 -4
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/grad_norm.py +3 -3
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/listwise.py +19 -6
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/pairwise.py +6 -4
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/loss/pointwise.py +8 -6
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/esmm.py +3 -26
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/mmoe.py +2 -24
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/ple.py +13 -35
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/poso.py +4 -28
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/share_bottom.py +1 -24
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/afm.py +3 -27
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/autoint.py +5 -38
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/dcn.py +1 -26
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/dcn_v2.py +5 -33
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/deepfm.py +2 -29
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/dien.py +2 -28
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/din.py +2 -27
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/eulernet.py +3 -30
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/ffm.py +0 -26
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/fibinet.py +8 -32
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/fm.py +0 -29
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/lr.py +0 -30
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/masknet.py +4 -30
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/pnn.py +4 -28
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/widedeep.py +0 -32
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/xdeepfm.py +0 -30
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/dssm.py +0 -24
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/dssm_v2.py +0 -24
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/mind.py +0 -20
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/sdm.py +0 -20
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/youtube_dnn.py +0 -21
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/sequential/hstu.py +0 -18
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/__init__.py +5 -1
- nextrec-0.4.21/nextrec/loss/loss_utils.py → nextrec-0.4.23/nextrec/utils/loss.py +17 -7
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/model.py +79 -1
- nextrec-0.4.23/nextrec/utils/types.py +98 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/pyproject.toml +1 -1
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_losses.py +54 -1
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_ranking_models.py +2 -3
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/example_multitask.py +1 -8
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/example_ranking_din.py +3 -5
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on nextrec.ipynb +1 -1
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/zh//345/277/253/351/200/237/345/205/245/351/227/250nextrec.ipynb +1 -1
- nextrec-0.4.21/nextrec/__version__.py +0 -1
- nextrec-0.4.21/nextrec/utils/types.py +0 -59
- {nextrec-0.4.21 → nextrec-0.4.23}/.github/workflows/publish.yml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/.github/workflows/tests.yml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/.gitignore +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/.readthedocs.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/CONTRIBUTING.md +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/LICENSE +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/MANIFEST.in +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/Feature Configuration.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/Model Parameters.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/Training Configuration.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/Training logs.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/logo.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/mmoe_tutorial.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/nextrec_diagram.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/assets/test data.png +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/dataset/ctcvr_task.csv +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/dataset/ecommerce_task.csv +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/dataset/match_task.csv +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/dataset/multitask_task.csv +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/dataset/ranking_task.csv +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/Makefile +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/make.bat +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/modules.rst +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.basic.rst +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.data.rst +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.loss.rst +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.rst +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/nextrec.utils.rst +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/docs/rtd/requirements.txt +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/callback.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/features.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/layers.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/loggers.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/basic/session.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/cli.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/batch_utils.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/data_processing.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/data_utils.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/data/dataloader.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/generative/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/multi_task/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/ranking/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/autorec.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/bpr.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/cl4srec.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/lightgcn.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/mf.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/rqvae.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/representation/s3rec.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/retrieval/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/models/sequential/sasrec.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/config.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/console.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/data.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/feature.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec/utils/torch_utils.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI.md +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/feature_config.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/din.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/predict_config.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/predict_config_template.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/train_config.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/nextrec_cli_preset/train_config_template.yaml +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/pytest.ini +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/requirements.txt +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/scripts/format_code.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/__init__.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/conftest.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/helpers.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/run_tests.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_base_model_regularization.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_generative_models.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_layers.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_match_models.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_multitask_models.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_preprocessor.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_utils_console.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_utils_data.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test/test_utils_embedding.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/test_requirements.txt +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/example_match.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/movielen_match_dssm.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/movielen_ranking_deepfm.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/zh//344/275/277/347/224/250RQ-VAE/346/236/204/345/273/272/350/257/255/344/271/211ID.ipynb" +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/notebooks/zh//345/246/202/344/275/225/344/275/277/347/224/250DataProcessor/350/277/233/350/241/214/351/242/204/345/244/204/347/220/206.ipynb" +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/run_all_match_models.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/run_all_multitask_models.py +0 -0
- {nextrec-0.4.21 → nextrec-0.4.23}/tutorials/run_all_ranking_models.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.4.
|
|
3
|
+
Version: 0.4.23
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -65,11 +65,11 @@ Description-Content-Type: text/markdown
|
|
|
65
65
|
|
|
66
66
|
<div align="center">
|
|
67
67
|
|
|
68
|
-
[](https://pypistats.org/packages/nextrec)
|
|
69
69
|

|
|
70
70
|

|
|
71
71
|

|
|
72
|
-

|
|
73
73
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
74
74
|
|
|
75
75
|
中文文档 | [English Version](README_en.md)
|
|
@@ -191,6 +191,8 @@ model = DIN(
|
|
|
191
191
|
dense_features=dense_features,
|
|
192
192
|
sparse_features=sparse_features,
|
|
193
193
|
sequence_features=sequence_features,
|
|
194
|
+
behavior_feature_name="sequence_0",
|
|
195
|
+
candidate_feature_name="item_id",
|
|
194
196
|
mlp_params=mlp_params,
|
|
195
197
|
attention_hidden_units=[80, 40],
|
|
196
198
|
attention_activation='sigmoid',
|
|
@@ -204,7 +206,7 @@ model = DIN(
|
|
|
204
206
|
session_id="din_tutorial", # 实验id,用于存放训练日志
|
|
205
207
|
)
|
|
206
208
|
|
|
207
|
-
#
|
|
209
|
+
# 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
|
|
208
210
|
model.compile(
|
|
209
211
|
optimizer = "adam",
|
|
210
212
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
|
|
@@ -247,11 +249,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
247
249
|
|
|
248
250
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
249
251
|
|
|
250
|
-
> 截止当前版本0.4.
|
|
252
|
+
> 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
251
253
|
|
|
252
254
|
## 兼容平台
|
|
253
255
|
|
|
254
|
-
当前最新版本为0.4.
|
|
256
|
+
当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
255
257
|
|
|
256
258
|
| 平台 | 配置 |
|
|
257
259
|
|------|------|
|
|
@@ -4,11 +4,11 @@
|
|
|
4
4
|
|
|
5
5
|
<div align="center">
|
|
6
6
|
|
|
7
|
-
[](https://pypistats.org/packages/nextrec)
|
|
8
8
|

|
|
9
9
|

|
|
10
10
|

|
|
11
|
-

|
|
12
12
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
13
13
|
|
|
14
14
|
中文文档 | [English Version](README_en.md)
|
|
@@ -130,6 +130,8 @@ model = DIN(
|
|
|
130
130
|
dense_features=dense_features,
|
|
131
131
|
sparse_features=sparse_features,
|
|
132
132
|
sequence_features=sequence_features,
|
|
133
|
+
behavior_feature_name="sequence_0",
|
|
134
|
+
candidate_feature_name="item_id",
|
|
133
135
|
mlp_params=mlp_params,
|
|
134
136
|
attention_hidden_units=[80, 40],
|
|
135
137
|
attention_activation='sigmoid',
|
|
@@ -143,7 +145,7 @@ model = DIN(
|
|
|
143
145
|
session_id="din_tutorial", # 实验id,用于存放训练日志
|
|
144
146
|
)
|
|
145
147
|
|
|
146
|
-
#
|
|
148
|
+
# 编译模型,优化器/损失/学习率调度器统一在 compile 中设置
|
|
147
149
|
model.compile(
|
|
148
150
|
optimizer = "adam",
|
|
149
151
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
|
|
@@ -186,11 +188,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
186
188
|
|
|
187
189
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
188
190
|
|
|
189
|
-
> 截止当前版本0.4.
|
|
191
|
+
> 截止当前版本0.4.23,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
190
192
|
|
|
191
193
|
## 兼容平台
|
|
192
194
|
|
|
193
|
-
当前最新版本为0.4.
|
|
195
|
+
当前最新版本为0.4.23,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
194
196
|
|
|
195
197
|
| 平台 | 配置 |
|
|
196
198
|
|------|------|
|
|
@@ -4,11 +4,11 @@
|
|
|
4
4
|
|
|
5
5
|
<div align="center">
|
|
6
6
|
|
|
7
|
-
[](https://pypistats.org/packages/nextrec)
|
|
8
8
|

|
|
9
9
|

|
|
10
10
|

|
|
11
|
-

|
|
12
12
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
13
13
|
|
|
14
14
|
English | [中文文档](README.md)
|
|
@@ -148,7 +148,7 @@ model = DIN(
|
|
|
148
148
|
session_id="din_tutorial", # experiment id for logs
|
|
149
149
|
)
|
|
150
150
|
|
|
151
|
-
# Compile model
|
|
151
|
+
# Compile model; configure optimizer/loss/scheduler via compile()
|
|
152
152
|
model.compile(
|
|
153
153
|
optimizer = "adam",
|
|
154
154
|
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5},
|
|
@@ -191,11 +191,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
191
191
|
|
|
192
192
|
Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.
|
|
193
193
|
|
|
194
|
-
> As of version 0.4.
|
|
194
|
+
> As of version 0.4.23, NextRec CLI supports single-machine training; distributed training features are currently under development.
|
|
195
195
|
|
|
196
196
|
## Platform Compatibility
|
|
197
197
|
|
|
198
|
-
The current version is 0.4.
|
|
198
|
+
The current version is 0.4.23. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
|
|
199
199
|
|
|
200
200
|
| Platform | Configuration |
|
|
201
201
|
|----------|---------------|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.4.23"
|
|
@@ -15,6 +15,7 @@ import torch.nn as nn
|
|
|
15
15
|
import torch.nn.functional as F
|
|
16
16
|
|
|
17
17
|
from nextrec.basic.layers import PredictionLayer
|
|
18
|
+
from nextrec.utils.types import TaskTypeName
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
class TaskHead(nn.Module):
|
|
@@ -27,9 +28,7 @@ class TaskHead(nn.Module):
|
|
|
27
28
|
|
|
28
29
|
def __init__(
|
|
29
30
|
self,
|
|
30
|
-
task_type:
|
|
31
|
-
Literal["binary", "regression"] | list[Literal["binary", "regression"]]
|
|
32
|
-
) = "binary",
|
|
31
|
+
task_type: TaskTypeName | list[TaskTypeName] = "binary",
|
|
33
32
|
task_dims: int | list[int] | None = None,
|
|
34
33
|
use_bias: bool = True,
|
|
35
34
|
return_logits: bool = False,
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Metrics computation and configuration for model evaluation.
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 29/12/2025
|
|
6
6
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -39,7 +39,6 @@ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
|
|
|
39
39
|
TASK_DEFAULT_METRICS = {
|
|
40
40
|
"binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
|
|
41
41
|
"regression": ["mse", "mae", "rmse", "r2", "mape"],
|
|
42
|
-
"multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
|
|
43
42
|
"matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
|
|
44
43
|
+ [f"recall@{k}" for k in (5, 10, 20)]
|
|
45
44
|
+ [f"ndcg@{k}" for k in (5, 10, 20)]
|
|
@@ -2,13 +2,14 @@
|
|
|
2
2
|
Base Model & Base Match Model Class
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 29/12/2025
|
|
6
6
|
Author: Yang Zhou,zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
9
9
|
import getpass
|
|
10
10
|
import logging
|
|
11
11
|
import os
|
|
12
|
+
import sys
|
|
12
13
|
import pickle
|
|
13
14
|
import socket
|
|
14
15
|
from pathlib import Path
|
|
@@ -16,6 +17,16 @@ from typing import Any, Literal
|
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
import pandas as pd
|
|
20
|
+
|
|
21
|
+
try:
|
|
22
|
+
import swanlab # type: ignore
|
|
23
|
+
except ModuleNotFoundError:
|
|
24
|
+
swanlab = None
|
|
25
|
+
try:
|
|
26
|
+
import wandb # type: ignore
|
|
27
|
+
except ModuleNotFoundError:
|
|
28
|
+
wandb = None
|
|
29
|
+
|
|
19
30
|
import torch
|
|
20
31
|
import torch.distributed as dist
|
|
21
32
|
import torch.nn as nn
|
|
@@ -60,8 +71,8 @@ from nextrec.loss import (
|
|
|
60
71
|
InfoNCELoss,
|
|
61
72
|
SampledSoftmaxLoss,
|
|
62
73
|
TripletLoss,
|
|
63
|
-
get_loss_fn,
|
|
64
74
|
)
|
|
75
|
+
from nextrec.utils.loss import get_loss_fn
|
|
65
76
|
from nextrec.loss.grad_norm import get_grad_norm_shared_params
|
|
66
77
|
from nextrec.utils.console import display_metrics_table, progress
|
|
67
78
|
from nextrec.utils.torch_utils import (
|
|
@@ -74,8 +85,20 @@ from nextrec.utils.torch_utils import (
|
|
|
74
85
|
to_tensor,
|
|
75
86
|
)
|
|
76
87
|
from nextrec.utils.config import safe_value
|
|
77
|
-
from nextrec.utils.model import
|
|
78
|
-
|
|
88
|
+
from nextrec.utils.model import (
|
|
89
|
+
compute_ranking_loss,
|
|
90
|
+
get_loss_list,
|
|
91
|
+
resolve_loss_weights,
|
|
92
|
+
get_training_modes,
|
|
93
|
+
)
|
|
94
|
+
from nextrec.utils.types import (
|
|
95
|
+
LossName,
|
|
96
|
+
OptimizerName,
|
|
97
|
+
SchedulerName,
|
|
98
|
+
TrainingModeName,
|
|
99
|
+
TaskTypeName,
|
|
100
|
+
MetricsName,
|
|
101
|
+
)
|
|
79
102
|
|
|
80
103
|
|
|
81
104
|
class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
@@ -84,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
84
107
|
raise NotImplementedError
|
|
85
108
|
|
|
86
109
|
@property
|
|
87
|
-
def default_task(self) ->
|
|
110
|
+
def default_task(self) -> TaskTypeName | list[TaskTypeName]:
|
|
88
111
|
raise NotImplementedError
|
|
89
112
|
|
|
90
113
|
def __init__(
|
|
@@ -94,11 +117,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
94
117
|
sequence_features: list[SequenceFeature] | None = None,
|
|
95
118
|
target: list[str] | str | None = None,
|
|
96
119
|
id_columns: list[str] | str | None = None,
|
|
97
|
-
task:
|
|
98
|
-
training_mode:
|
|
99
|
-
Literal["pointwise", "pairwise", "listwise"]
|
|
100
|
-
| list[Literal["pointwise", "pairwise", "listwise"]]
|
|
101
|
-
) = "pointwise",
|
|
120
|
+
task: TaskTypeName | list[TaskTypeName] | None = None,
|
|
121
|
+
training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
|
|
102
122
|
embedding_l1_reg: float = 0.0,
|
|
103
123
|
dense_l1_reg: float = 0.0,
|
|
104
124
|
embedding_l2_reg: float = 0.0,
|
|
@@ -136,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
136
156
|
world_size: Number of processes (defaults to env WORLD_SIZE).
|
|
137
157
|
local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
|
|
138
158
|
ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
|
|
159
|
+
|
|
160
|
+
Note:
|
|
161
|
+
Optimizer, scheduler, and loss are configured via compile().
|
|
139
162
|
"""
|
|
140
163
|
super(BaseModel, self).__init__()
|
|
141
164
|
|
|
@@ -168,25 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
168
191
|
dense_features, sparse_features, sequence_features, target, id_columns
|
|
169
192
|
)
|
|
170
193
|
|
|
171
|
-
self.task = self.default_task
|
|
194
|
+
self.task = task or self.default_task
|
|
172
195
|
self.nums_task = len(self.task) if isinstance(self.task, list) else 1
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
if
|
|
176
|
-
|
|
177
|
-
"[BaseModel-init Error] training_mode list length must match number of tasks."
|
|
178
|
-
)
|
|
179
|
-
else:
|
|
180
|
-
training_modes = [training_mode] * self.nums_task
|
|
181
|
-
if any(
|
|
182
|
-
mode not in {"pointwise", "pairwise", "listwise"}
|
|
183
|
-
for mode in training_modes
|
|
184
|
-
):
|
|
185
|
-
raise ValueError(
|
|
186
|
-
"[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
|
|
187
|
-
)
|
|
188
|
-
self.training_modes = training_modes
|
|
189
|
-
self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
|
|
196
|
+
self.training_modes = get_training_modes(training_mode, self.nums_task)
|
|
197
|
+
self.training_mode = (
|
|
198
|
+
self.training_modes if self.nums_task > 1 else self.training_modes[0]
|
|
199
|
+
)
|
|
190
200
|
|
|
191
201
|
self.embedding_l1_reg = embedding_l1_reg
|
|
192
202
|
self.dense_l1_reg = dense_l1_reg
|
|
@@ -194,7 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
194
204
|
self.dense_l2_reg = dense_l2_reg
|
|
195
205
|
self.regularization_weights = []
|
|
196
206
|
self.embedding_params = []
|
|
197
|
-
|
|
207
|
+
|
|
208
|
+
self.ignore_label = None
|
|
209
|
+
self.compiled = False
|
|
198
210
|
|
|
199
211
|
self.max_gradient_norm = 1.0
|
|
200
212
|
self.logger_initialized = False
|
|
@@ -407,6 +419,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
407
419
|
loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
|
|
408
420
|
loss_params: dict | list[dict] | None = None,
|
|
409
421
|
loss_weights: int | float | list[int | float] | dict | str | None = None,
|
|
422
|
+
ignore_label: int | float | None = -1,
|
|
410
423
|
):
|
|
411
424
|
"""
|
|
412
425
|
Configure the model for training.
|
|
@@ -419,34 +432,17 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
419
432
|
loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
|
|
420
433
|
loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
|
|
421
434
|
Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
|
|
435
|
+
ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
|
|
422
436
|
"""
|
|
437
|
+
self.ignore_label = ignore_label
|
|
423
438
|
default_losses = {
|
|
424
439
|
"pointwise": "bce",
|
|
425
440
|
"pairwise": "bpr",
|
|
426
441
|
"listwise": "listnet",
|
|
427
442
|
}
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
elif isinstance(effective_loss, list):
|
|
432
|
-
if not effective_loss:
|
|
433
|
-
loss_list = [default_losses[mode] for mode in self.training_modes]
|
|
434
|
-
else:
|
|
435
|
-
if len(effective_loss) != self.nums_task:
|
|
436
|
-
raise ValueError(
|
|
437
|
-
f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
|
|
438
|
-
)
|
|
439
|
-
loss_list = list(effective_loss)
|
|
440
|
-
else:
|
|
441
|
-
loss_list = [effective_loss] * self.nums_task
|
|
442
|
-
|
|
443
|
-
for idx, mode in enumerate(self.training_modes):
|
|
444
|
-
if isinstance(loss_list[idx], str) and loss_list[idx] in {
|
|
445
|
-
"bce",
|
|
446
|
-
"binary_crossentropy",
|
|
447
|
-
}:
|
|
448
|
-
if mode in {"pairwise", "listwise"}:
|
|
449
|
-
loss_list[idx] = default_losses[mode]
|
|
443
|
+
loss_list = get_loss_list(
|
|
444
|
+
loss, self.training_modes, self.nums_task, default_losses
|
|
445
|
+
)
|
|
450
446
|
self.loss_params = loss_params or {}
|
|
451
447
|
optimizer_params = optimizer_params or {}
|
|
452
448
|
self.optimizer_name = (
|
|
@@ -510,36 +506,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
510
506
|
nums_task=self.nums_task, device=self.device, **grad_norm_params
|
|
511
507
|
)
|
|
512
508
|
self.loss_weights = None
|
|
513
|
-
elif loss_weights is None:
|
|
514
|
-
self.loss_weights = None
|
|
515
|
-
elif self.nums_task == 1:
|
|
516
|
-
if isinstance(loss_weights, (list, tuple)):
|
|
517
|
-
if len(loss_weights) != 1:
|
|
518
|
-
raise ValueError(
|
|
519
|
-
"[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
|
|
520
|
-
)
|
|
521
|
-
loss_weights = loss_weights[0]
|
|
522
|
-
self.loss_weights = [float(loss_weights)] # type: ignore
|
|
523
509
|
else:
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
elif isinstance(loss_weights, (list, tuple)):
|
|
527
|
-
weights = [float(w) for w in loss_weights]
|
|
528
|
-
if len(weights) != self.nums_task:
|
|
529
|
-
raise ValueError(
|
|
530
|
-
f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
|
|
531
|
-
)
|
|
532
|
-
else:
|
|
533
|
-
raise TypeError(
|
|
534
|
-
f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
|
|
535
|
-
)
|
|
536
|
-
self.loss_weights = weights
|
|
510
|
+
self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
|
|
511
|
+
self.compiled = True
|
|
537
512
|
|
|
538
513
|
def compute_loss(self, y_pred, y_true):
|
|
539
514
|
if y_true is None:
|
|
540
515
|
raise ValueError(
|
|
541
516
|
"[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
|
|
542
517
|
)
|
|
518
|
+
|
|
543
519
|
# single-task
|
|
544
520
|
if self.nums_task == 1:
|
|
545
521
|
if y_pred.dim() == 1:
|
|
@@ -547,13 +523,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
547
523
|
if y_true.dim() == 1:
|
|
548
524
|
y_true = y_true.view(-1, 1)
|
|
549
525
|
if y_pred.shape != y_true.shape:
|
|
550
|
-
raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
|
|
551
|
-
loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
|
|
552
|
-
if loss_fn is None:
|
|
553
526
|
raise ValueError(
|
|
554
|
-
"[BaseModel-compute_loss Error]
|
|
527
|
+
f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
|
|
555
528
|
)
|
|
529
|
+
|
|
530
|
+
loss_fn = self.loss_fn[0]
|
|
531
|
+
|
|
532
|
+
if self.ignore_label is not None:
|
|
533
|
+
valid_mask = y_true != self.ignore_label
|
|
534
|
+
if valid_mask.dim() > 1:
|
|
535
|
+
valid_mask = valid_mask.all(dim=1)
|
|
536
|
+
if not torch.any(valid_mask): # if no valid labels, return zero loss
|
|
537
|
+
return y_pred.sum() * 0.0
|
|
538
|
+
|
|
539
|
+
y_pred = y_pred[valid_mask]
|
|
540
|
+
y_true = y_true[valid_mask]
|
|
541
|
+
|
|
556
542
|
mode = self.training_modes[0]
|
|
543
|
+
|
|
557
544
|
task_dim = (
|
|
558
545
|
self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
|
|
559
546
|
)
|
|
@@ -584,7 +571,19 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
584
571
|
for i, (start, end) in enumerate(slices): # type: ignore
|
|
585
572
|
y_pred_i = y_pred[:, start:end]
|
|
586
573
|
y_true_i = y_true[:, start:end]
|
|
574
|
+
# mask ignored labels
|
|
575
|
+
if self.ignore_label is not None:
|
|
576
|
+
valid_mask = y_true_i != self.ignore_label
|
|
577
|
+
if valid_mask.dim() > 1:
|
|
578
|
+
valid_mask = valid_mask.all(dim=1)
|
|
579
|
+
if not torch.any(valid_mask):
|
|
580
|
+
task_losses.append(y_pred_i.sum() * 0.0)
|
|
581
|
+
continue
|
|
582
|
+
y_pred_i = y_pred_i[valid_mask]
|
|
583
|
+
y_true_i = y_true_i[valid_mask]
|
|
584
|
+
|
|
587
585
|
mode = self.training_modes[i]
|
|
586
|
+
|
|
588
587
|
if mode in {"pairwise", "listwise"}:
|
|
589
588
|
task_loss = compute_ranking_loss(
|
|
590
589
|
training_mode=mode,
|
|
@@ -594,7 +593,11 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
594
593
|
)
|
|
595
594
|
else:
|
|
596
595
|
task_loss = self.loss_fn[i](y_pred_i, y_true_i)
|
|
596
|
+
# task_loss = normalize_task_loss(
|
|
597
|
+
# task_loss, valid_count, total_count
|
|
598
|
+
# ) # normalize by valid samples to avoid loss scale issues
|
|
597
599
|
task_losses.append(task_loss)
|
|
600
|
+
|
|
598
601
|
if self.grad_norm is not None:
|
|
599
602
|
if self.grad_norm_shared_params is None:
|
|
600
603
|
self.grad_norm_shared_params = get_grad_norm_shared_params(
|
|
@@ -651,7 +654,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
651
654
|
train_data=None,
|
|
652
655
|
valid_data=None,
|
|
653
656
|
metrics: (
|
|
654
|
-
list[
|
|
657
|
+
list[MetricsName] | dict[str, list[MetricsName]] | None
|
|
655
658
|
) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
|
|
656
659
|
epochs: int = 1,
|
|
657
660
|
shuffle: bool = True,
|
|
@@ -665,6 +668,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
665
668
|
use_tensorboard: bool = True,
|
|
666
669
|
use_wandb: bool = False,
|
|
667
670
|
use_swanlab: bool = False,
|
|
671
|
+
wandb_api: str | None = None,
|
|
672
|
+
swanlab_api: str | None = None,
|
|
668
673
|
wandb_kwargs: dict | None = None,
|
|
669
674
|
swanlab_kwargs: dict | None = None,
|
|
670
675
|
auto_ddp_sampler: bool = True,
|
|
@@ -694,6 +699,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
694
699
|
use_tensorboard: Enable tensorboard logging.
|
|
695
700
|
use_wandb: Enable Weights & Biases logging.
|
|
696
701
|
use_swanlab: Enable SwanLab logging.
|
|
702
|
+
wandb_api: W&B API key for non-tty login.
|
|
703
|
+
swanlab_api: SwanLab API key for non-tty login.
|
|
697
704
|
wandb_kwargs: Optional kwargs for wandb.init(...).
|
|
698
705
|
swanlab_kwargs: Optional kwargs for swanlab.init(...).
|
|
699
706
|
auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
|
|
@@ -711,6 +718,16 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
711
718
|
)
|
|
712
719
|
self.to(self.device)
|
|
713
720
|
|
|
721
|
+
if not self.compiled:
|
|
722
|
+
self.compile(
|
|
723
|
+
optimizer="adam",
|
|
724
|
+
optimizer_params={},
|
|
725
|
+
scheduler=None,
|
|
726
|
+
scheduler_params={},
|
|
727
|
+
loss=None,
|
|
728
|
+
loss_params={},
|
|
729
|
+
)
|
|
730
|
+
|
|
714
731
|
if (
|
|
715
732
|
self.distributed
|
|
716
733
|
and dist.is_available()
|
|
@@ -785,6 +802,24 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
|
|
|
785
802
|
}
|
|
786
803
|
training_config: dict = safe_value(training_config) # type: ignore
|
|
787
804
|
|
|
805
|
+
if self.is_main_process:
|
|
806
|
+
is_tty = sys.stdin.isatty() and sys.stdout.isatty()
|
|
807
|
+
if not is_tty:
|
|
808
|
+
if use_wandb and wandb_api:
|
|
809
|
+
if wandb is None:
|
|
810
|
+
logging.warning(
|
|
811
|
+
"[BaseModel-fit] wandb not installed, skip wandb login."
|
|
812
|
+
)
|
|
813
|
+
else:
|
|
814
|
+
wandb.login(key=wandb_api)
|
|
815
|
+
if use_swanlab and swanlab_api:
|
|
816
|
+
if swanlab is None:
|
|
817
|
+
logging.warning(
|
|
818
|
+
"[BaseModel-fit] swanlab not installed, skip swanlab login."
|
|
819
|
+
)
|
|
820
|
+
else:
|
|
821
|
+
swanlab.login(api_key=swanlab_api)
|
|
822
|
+
|
|
788
823
|
self.training_logger = (
|
|
789
824
|
TrainingLogger(
|
|
790
825
|
session=self.session,
|
|
@@ -2164,7 +2199,7 @@ class BaseMatchModel(BaseModel):
|
|
|
2164
2199
|
scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
|
|
2165
2200
|
loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
|
|
2166
2201
|
loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
|
|
2167
|
-
loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
|
|
2202
|
+
loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
|
|
2168
2203
|
"""
|
|
2169
2204
|
if self.training_mode not in self.support_training_modes:
|
|
2170
2205
|
raise ValueError(
|
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Summary utilities for BaseModel.
|
|
3
|
+
|
|
4
|
+
Date: create on 03/12/2025
|
|
5
|
+
Checkpoint: edit on 29/12/2025
|
|
6
|
+
Author: Yang Zhou,zyaztec@gmail.com
|
|
3
7
|
"""
|
|
4
8
|
|
|
5
9
|
from __future__ import annotations
|
|
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader
|
|
|
12
16
|
|
|
13
17
|
from nextrec.basic.loggers import colorize, format_kv
|
|
14
18
|
from nextrec.data.data_processing import extract_label_arrays, get_data_length
|
|
19
|
+
from nextrec.utils.types import TaskTypeName
|
|
15
20
|
|
|
16
21
|
|
|
17
22
|
class SummarySet:
|
|
23
|
+
model_name: str
|
|
24
|
+
dense_features: list[Any]
|
|
25
|
+
sparse_features: list[Any]
|
|
26
|
+
sequence_features: list[Any]
|
|
27
|
+
task: TaskTypeName | list[TaskTypeName]
|
|
28
|
+
target_columns: list[str]
|
|
29
|
+
nums_task: int
|
|
30
|
+
metrics: Any
|
|
31
|
+
device: Any
|
|
32
|
+
optimizer_name: str
|
|
33
|
+
optimizer_params: dict[str, Any]
|
|
34
|
+
scheduler_name: str | None
|
|
35
|
+
scheduler_params: dict[str, Any]
|
|
36
|
+
loss_config: Any
|
|
37
|
+
loss_weights: Any
|
|
38
|
+
grad_norm: Any
|
|
39
|
+
embedding_l1_reg: float
|
|
40
|
+
embedding_l2_reg: float
|
|
41
|
+
dense_l1_reg: float
|
|
42
|
+
dense_l2_reg: float
|
|
43
|
+
early_stop_patience: int
|
|
44
|
+
max_gradient_norm: float | None
|
|
45
|
+
metrics_sample_limit: int | None
|
|
46
|
+
session_id: str | None
|
|
47
|
+
features_config_path: str
|
|
48
|
+
checkpoint_path: str
|
|
49
|
+
train_data_summary: dict[str, Any] | None
|
|
50
|
+
valid_data_summary: dict[str, Any] | None
|
|
51
|
+
|
|
18
52
|
def build_data_summary(
|
|
19
53
|
self, data: Any, data_loader: DataLoader | None, sample_key: str
|
|
20
54
|
):
|
|
@@ -305,7 +339,7 @@ class SummarySet:
|
|
|
305
339
|
lines = details.get("lines", [])
|
|
306
340
|
logger.info(f"{target_name}:")
|
|
307
341
|
for label, value in lines:
|
|
308
|
-
logger.info(format_kv(label, value))
|
|
342
|
+
logger.info(f" {format_kv(label, value)}")
|
|
309
343
|
|
|
310
344
|
if self.valid_data_summary:
|
|
311
345
|
if self.train_data_summary:
|
|
@@ -320,4 +354,4 @@ class SummarySet:
|
|
|
320
354
|
lines = details.get("lines", [])
|
|
321
355
|
logger.info(f"{target_name}:")
|
|
322
356
|
for label, value in lines:
|
|
323
|
-
logger.info(format_kv(label, value))
|
|
357
|
+
logger.info(f" {format_kv(label, value)}")
|