nextrec 0.4.16__tar.gz → 0.4.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.4.16 → nextrec-0.4.17}/PKG-INFO +4 -4
- {nextrec-0.4.16 → nextrec-0.4.17}/README.md +3 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/README_en.md +3 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/conf.py +1 -1
- nextrec-0.4.17/nextrec/__version__.py +1 -0
- nextrec-0.4.17/nextrec/basic/heads.py +101 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/model.py +10 -9
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/esmm.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/mmoe.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/ple.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/poso.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/share_bottom.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/afm.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/autoint.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dcn.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dcn_v2.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/deepfm.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dien.py +2 -2
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/din.py +2 -2
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/eulernet.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/ffm.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/fibinet.py +2 -2
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/fm.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/lr.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/masknet.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/pnn.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/widedeep.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/xdeepfm.py +4 -3
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/console.py +19 -5
- {nextrec-0.4.16 → nextrec-0.4.17}/pyproject.toml +1 -1
- nextrec-0.4.16/nextrec/__version__.py +0 -1
- {nextrec-0.4.16 → nextrec-0.4.17}/.github/workflows/publish.yml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/.github/workflows/tests.yml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/.gitignore +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/.readthedocs.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/CONTRIBUTING.md +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/LICENSE +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/MANIFEST.in +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/Feature Configuration.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/Model Parameters.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/Training Configuration.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/Training logs.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/logo.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/mmoe_tutorial.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/nextrec_diagram.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/assets/test data.png +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/dataset/ctcvr_task.csv +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/dataset/ecommerce_task.csv +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/dataset/match_task.csv +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/dataset/multitask_task.csv +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/dataset/ranking_task.csv +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/en/Getting started guide.md +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/Makefile +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/index.md +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/make.bat +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/modules.rst +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.basic.rst +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.data.rst +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.loss.rst +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.rst +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/nextrec.utils.rst +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/rtd/requirements.txt +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/docs/zh/快速上手.md +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/activation.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/callback.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/features.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/layers.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/loggers.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/metrics.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/session.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/cli.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/batch_utils.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/data_processing.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/data_utils.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/dataloader.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/data/preprocessor.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/grad_norm.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/listwise.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/loss_utils.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/pairwise.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/loss/pointwise.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/generative/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/autorec.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/bpr.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/cl4srec.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/lightgcn.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/mf.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/rqvae.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/representation/s3rec.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/dssm.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/dssm_v2.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/mind.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/sdm.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/retrieval/youtube_dnn.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/sequential/hstu.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/sequential/sasrec.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/config.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/data.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/feature.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/model.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec/utils/torch_utils.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/NextRec-CLI.md +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/feature_config.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/din.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/predict_config.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/predict_config_template.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/train_config.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/nextrec_cli_preset/train_config_template.yaml +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/pytest.ini +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/requirements.txt +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/scripts/format_code.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/__init__.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/conftest.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/helpers.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/run_tests.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_base_model_regularization.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_generative_models.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_layers.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_losses.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_match_models.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_multitask_models.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_preprocessor.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_ranking_models.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_utils_console.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_utils_data.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test/test_utils_embedding.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/test_requirements.txt +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/distributed/example_distributed_training.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/distributed/example_distributed_training_large_dataset.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/example_multitask.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/example_ranking_din.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/movielen_match_dssm.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/movielen_ranking_deepfm.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/run_all_match_models.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/run_all_multitask_models.py +0 -0
- {nextrec-0.4.16 → nextrec-0.4.17}/tutorials/run_all_ranking_models.py +0 -0
{nextrec-0.4.16 → nextrec-0.4.17}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nextrec
-Version: 0.4.16
+Version: 0.4.17
 Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
 Project-URL: Homepage, https://github.com/zerolovesea/NextRec
 Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -68,7 +68,7 @@ Description-Content-Type: text/markdown
 
 
 
-
+
 
 
 中文文档 | [English Version](README_en.md)
@@ -244,11 +244,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
 nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 ```
 
-> 截止当前版本0.4.16,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+> 截止当前版本0.4.17,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
 
 ## 兼容平台
 
-当前最新版本为0.4.16,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+当前最新版本为0.4.17,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
 
 | 平台 | 配置 |
 |------|------|
````
{nextrec-0.4.16 → nextrec-0.4.17}/README.md

````diff
@@ -9,7 +9,7 @@
 
 
 
-
+
 
 
 中文文档 | [English Version](README_en.md)
@@ -185,11 +185,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
 nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 ```
 
-> 截止当前版本0.4.16,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
+> 截止当前版本0.4.17,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
 
 ## 兼容平台
 
-当前最新版本为0.4.16,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
+当前最新版本为0.4.17,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
 
 | 平台 | 配置 |
 |------|------|
````
{nextrec-0.4.16 → nextrec-0.4.17}/README_en.md

````diff
@@ -9,7 +9,7 @@
 
 
 
-
+
 
 English | [中文文档](README.md)
 
@@ -188,11 +188,11 @@ nextrec --mode=train --train_config=path/to/train_config.yaml
 nextrec --mode=predict --predict_config=path/to/predict_config.yaml
 ```
 
-> As of version 0.4.16, NextRec CLI supports single-machine training; distributed training features are currently under development.
+> As of version 0.4.17, NextRec CLI supports single-machine training; distributed training features are currently under development.
 
 ## Platform Compatibility
 
-The current version is 0.4.16. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
+The current version is 0.4.17. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
 
 | Platform | Configuration |
 |----------|---------------|
````
nextrec-0.4.17/nextrec/__version__.py (new file)

```diff
@@ -0,0 +1 @@
+__version__ = "0.4.17"
```
nextrec-0.4.17/nextrec/basic/heads.py (new file)

```diff
@@ -0,0 +1,101 @@
+"""
+Task head implementations for NextRec models.
+
+Date: create on 23/12/2025
+Author: Yang Zhou, zyaztec@gmail.com
+"""
+
+from __future__ import annotations
+
+from typing import Literal
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from nextrec.basic.layers import PredictionLayer
+
+
+class TaskHead(nn.Module):
+    """
+    Unified task head for ranking/regression/multi-task outputs.
+
+    This wraps PredictionLayer so models can depend on a "Head" abstraction
+    without changing their existing forward signatures.
+    """
+
+    def __init__(
+        self,
+        task_type: str | list[str] = "binary",
+        task_dims: int | list[int] | None = None,
+        use_bias: bool = True,
+        return_logits: bool = False,
+    ) -> None:
+        super().__init__()
+        self.prediction = PredictionLayer(
+            task_type=task_type,
+            task_dims=task_dims,
+            use_bias=use_bias,
+            return_logits=return_logits,
+        )
+        # Expose commonly used attributes for compatibility with PredictionLayer.
+        self.task_types = self.prediction.task_types
+        self.task_dims = self.prediction.task_dims
+        self.task_slices = self.prediction.task_slices
+        self.total_dim = self.prediction.total_dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.prediction(x)
+
+
+class RetrievalHead(nn.Module):
+    """
+    Retrieval head for two-tower models.
+
+    It computes similarity for pointwise training/inference, and returns
+    raw embeddings for in-batch negative sampling in pairwise/listwise modes.
+    """
+
+    def __init__(
+        self,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        temperature: float = 1.0,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        apply_sigmoid: bool = True,
+    ) -> None:
+        super().__init__()
+        self.similarity_metric = similarity_metric
+        self.temperature = temperature
+        self.training_mode = training_mode
+        self.apply_sigmoid = apply_sigmoid
+
+    def forward(
+        self,
+        user_emb: torch.Tensor,
+        item_emb: torch.Tensor,
+        similarity_fn=None,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        if self.training and self.training_mode in {"pairwise", "listwise"}:
+            return user_emb, item_emb
+
+        if similarity_fn is not None:
+            similarity = similarity_fn(user_emb, item_emb)
+        else:
+            if user_emb.dim() == 2 and item_emb.dim() == 3:
+                user_emb = user_emb.unsqueeze(1)
+
+            if self.similarity_metric == "dot":
+                similarity = torch.sum(user_emb * item_emb, dim=-1)
+            elif self.similarity_metric == "cosine":
+                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
+            elif self.similarity_metric == "euclidean":
+                similarity = -torch.sum((user_emb - item_emb) ** 2, dim=-1)
+            else:
+                raise ValueError(
+                    f"Unknown similarity metric: {self.similarity_metric}"
+                )
+
+        similarity = similarity / self.temperature
+        if self.training_mode == "pointwise" and self.apply_sigmoid:
+            return torch.sigmoid(similarity)
+        return similarity
```
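The new `nextrec/basic/heads.py` module gives models a thin head abstraction: `TaskHead` forwards to the existing `PredictionLayer`, while `RetrievalHead` either scores user/item pairs or hands the raw tower embeddings back for in-batch negative sampling. The following is a minimal usage sketch, not taken from the package docs; it assumes nextrec 0.4.17 is installed and uses made-up tensor shapes.

```python
# Hedged sketch, assuming nextrec >= 0.4.17 is installed; shapes are illustrative.
import torch

from nextrec.basic.heads import RetrievalHead, TaskHead

# TaskHead wraps PredictionLayer: give it the model's final logits.
task_head = TaskHead(task_type="binary")
logits = torch.randn(8, 1)                 # [B, 1] logits from some tower
preds = task_head(logits)                  # output of the wrapped PredictionLayer

# RetrievalHead in pointwise mode scores user/item pairs directly.
retrieval = RetrievalHead(similarity_metric="cosine", temperature=0.2,
                          training_mode="pointwise")
retrieval.eval()
user_emb = torch.randn(8, 64)              # [B, D]
item_emb = torch.randn(8, 64)              # [B, D]
scores = retrieval(user_emb, item_emb)     # [B]: sigmoid(cosine / temperature)

# In pairwise/listwise training mode it returns the raw embeddings so the loss
# can build in-batch negatives itself.
pair_head = RetrievalHead(training_mode="pairwise").train()
u, i = pair_head(user_emb, item_emb)
```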
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/basic/model.py

```diff
@@ -38,6 +38,7 @@ from nextrec.basic.features import (
     SequenceFeature,
     SparseFeature,
 )
+from nextrec.basic.heads import RetrievalHead
 from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
 from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
 from nextrec.basic.session import create_session, resolve_save_path
@@ -2115,6 +2116,12 @@ class BaseMatchModel(BaseModel):
         )
         self.user_feature_names = {feature.name for feature in self.user_features_all}
         self.item_feature_names = {feature.name for feature in self.item_features_all}
+        self.head = RetrievalHead(
+            similarity_metric=self.similarity_metric,
+            temperature=self.temperature,
+            training_mode=self.training_mode,
+            apply_sigmoid=True,
+        )
 
     def compile(
         self,
@@ -2244,15 +2251,9 @@ class BaseMatchModel(BaseModel):
         user_emb = self.user_tower(user_input) # [B, D]
         item_emb = self.item_tower(item_input) # [B, D]
 
-
-
-
-        similarity = self.compute_similarity(user_emb, item_emb) # [B]
-
-        if self.training_mode == "pointwise":
-            return torch.sigmoid(similarity)
-        else:
-            return similarity
+        return self.head(
+            user_emb, item_emb, similarity_fn=self.compute_similarity
+        )
 
     def compute_loss(self, y_pred, y_true):
         if self.training_mode == "pointwise":
```
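For two-tower models, `BaseMatchModel` now builds a `RetrievalHead` in `__init__` and delegates the final scoring to it. Below is a toy module wired the same way, hedged and purely illustrative: the linear towers and dimensions are invented, and the real model additionally passes its own `compute_similarity` as `similarity_fn`.

```python
# Illustrative toy two-tower module mirroring the BaseMatchModel wiring above;
# the linear "towers" and dimensions are invented for the sketch.
import torch
import torch.nn as nn

from nextrec.basic.heads import RetrievalHead


class ToyTwoTower(nn.Module):
    def __init__(self, dim: int = 32) -> None:
        super().__init__()
        self.user_tower = nn.Linear(dim, dim)
        self.item_tower = nn.Linear(dim, dim)
        # Same construction BaseMatchModel performs in its __init__.
        self.head = RetrievalHead(
            similarity_metric="dot",
            temperature=1.0,
            training_mode="pointwise",
            apply_sigmoid=True,
        )

    def forward(self, user_x: torch.Tensor, item_x: torch.Tensor) -> torch.Tensor:
        user_emb = self.user_tower(user_x)   # [B, D]
        item_emb = self.item_tower(item_x)   # [B, D]
        # BaseMatchModel also passes similarity_fn=self.compute_similarity,
        # which takes precedence over the head's built-in metric.
        return self.head(user_emb, item_emb)


scores = ToyTwoTower()(torch.randn(4, 32), torch.randn(4, 32))  # [B] in (0, 1)
```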
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/esmm.py

```diff
@@ -45,7 +45,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -139,7 +140,7 @@ class ESMM(BaseModel):
         # CVR tower
         self.cvr_tower = MLP(input_dim=input_dim, output_layer=True, **cvr_params)
         self.grad_norm_shared_modules = ["embedding"]
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = TaskHead(
             task_type=self.default_task, task_dims=[1, 1]
         )
         # Register regularization weights
@@ -167,4 +168,4 @@ class ESMM(BaseModel):
 
         # Output: [CTR, CTCVR], We supervise CTR with click labels and CTCVR with conversion labels
         y = torch.cat([ctr, ctcvr], dim=1) # [B, 2]
-        return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
+        return y # [B, 2], where y[:, 0] is CTR and y[:, 1] is CTCVR
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/mmoe.py

```diff
@@ -46,7 +46,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -172,7 +173,7 @@ class MMOE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = TaskHead(
             task_type=self.default_task, task_dims=[1] * self.num_tasks
         )
         # Register regularization weights
@@ -219,4 +220,4 @@ class MMOE(BaseModel):
 
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/ple.py

```diff
@@ -49,7 +49,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.model import get_mlp_output_dim
 
@@ -302,7 +303,7 @@ class PLE(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=expert_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = TaskHead(
             task_type=self.default_task, task_dims=[1] * self.num_tasks
         )
         # Register regularization weights
@@ -336,4 +337,4 @@ class PLE(BaseModel):
 
         # [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/poso.py

```diff
@@ -44,7 +44,8 @@ import torch.nn.functional as F
 
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.model import select_features
 
@@ -487,7 +488,7 @@ class POSO(BaseModel):
             self.grad_norm_shared_modules = ["embedding"]
         else:
             self.grad_norm_shared_modules = ["embedding", "mmoe"]
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = TaskHead(
             task_type=self.default_task,
             task_dims=[1] * self.num_tasks,
         )
@@ -524,4 +525,4 @@ class POSO(BaseModel):
             task_outputs.append(logit)
 
         y = torch.cat(task_outputs, dim=1)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/multi_task/share_bottom.py

```diff
@@ -43,7 +43,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -142,7 +143,7 @@ class ShareBottom(BaseModel):
         for tower_params in tower_params_list:
             tower = MLP(input_dim=bottom_output_dim, output_layer=True, **tower_params)
             self.towers.append(tower)
-        self.prediction_layer = PredictionLayer(
+        self.prediction_layer = TaskHead(
             task_type=self.default_task, task_dims=[1] * self.num_tasks
         )
         # Register regularization weights
@@ -171,4 +172,4 @@ class ShareBottom(BaseModel):
 
         # Stack outputs: [B, num_tasks]
         y = torch.cat(task_outputs, dim=1)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/afm.py

```diff
@@ -40,7 +40,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import EmbeddingLayer, InputMask, PredictionLayer
+from nextrec.basic.layers import EmbeddingLayer, InputMask
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -141,7 +142,7 @@ class AFM(BaseModel):
         self.attention_p = nn.Linear(attention_dim, 1, bias=False)
         self.attention_dropout = nn.Dropout(attention_dropout)
         self.output_projection = nn.Linear(self.embedding_dim, 1, bias=False)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
         self.input_mask = InputMask()
 
         # Register regularization weights
@@ -243,4 +244,4 @@ class AFM(BaseModel):
         y_afm = self.output_projection(weighted_sum)
 
         y = y_linear + y_afm
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/autoint.py

```diff
@@ -58,7 +58,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention, PredictionLayer
+from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -162,7 +163,7 @@ class AutoInt(BaseModel):
 
         # Final prediction layer
         self.fc = nn.Linear(num_fields * att_embedding_dim, 1)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         # Register regularization weights
         self.register_regularization_weights(
@@ -206,4 +207,4 @@ class AutoInt(BaseModel):
             start_dim=1
         ) # [B, num_fields * att_embedding_dim]
         y = self.fc(attention_output_flat) # [B, 1]
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dcn.py

```diff
@@ -54,7 +54,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -163,7 +164,7 @@ class DCN(BaseModel):
         # Final layer only uses cross network output
         self.final_layer = nn.Linear(input_dim, 1)
 
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         # Register regularization weights
         self.register_regularization_weights(
@@ -197,4 +198,4 @@ class DCN(BaseModel):
 
         # Final prediction
         y = self.final_layer(combined)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dcn_v2.py

```diff
@@ -47,7 +47,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -272,7 +273,7 @@ class DCNv2(BaseModel):
             final_input_dim = input_dim
 
         self.final_layer = nn.Linear(final_input_dim, 1)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         self.register_regularization_weights(
             embedding_attr="embedding",
@@ -301,4 +302,4 @@ class DCNv2(BaseModel):
             combined = cross_out
 
         logit = self.final_layer(combined)
-        return self.prediction_layer(logit)
+        return self.prediction_layer(logit)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/deepfm.py

```diff
@@ -45,7 +45,8 @@ embedding,无需手工构造交叉特征即可端到端训练,常用于 CTR/
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -111,7 +112,7 @@ class DeepFM(BaseModel):
         self.linear = LR(fm_emb_dim_total)
         self.fm = FM(reduce_sum=True)
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         # Register regularization weights
         self.register_regularization_weights(
@@ -133,4 +134,4 @@ class DeepFM(BaseModel):
         y_deep = self.mlp(input_deep) # [B, 1]
 
         y = y_linear + y_fm + y_deep
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/dien.py

```diff
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
     MLP,
     AttentionPoolingLayer,
     EmbeddingLayer,
-    PredictionLayer,
 )
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -346,7 +346,7 @@ class DIEN(BaseModel):
         )
 
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         self.register_regularization_weights(
             embedding_attr="embedding",
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/din.py

```diff
@@ -55,8 +55,8 @@ from nextrec.basic.layers import (
     MLP,
     AttentionPoolingLayer,
     EmbeddingLayer,
-    PredictionLayer,
 )
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -173,7 +173,7 @@ class DIN(BaseModel):
 
         # MLP for final prediction
         self.mlp = MLP(input_dim=mlp_input_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         # Register regularization weights
         self.register_regularization_weights(
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/eulernet.py

```diff
@@ -38,7 +38,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
+from nextrec.basic.layers import LR, EmbeddingLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -295,7 +296,7 @@ class EulerNet(BaseModel):
         else:
             self.linear = None
 
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         modules = ["mapping", "layers", "w", "w_im"]
         if self.use_linear:
@@ -331,4 +332,4 @@ class EulerNet(BaseModel):
             r, p = layer(r, p)
         r_flat = r.reshape(r.size(0), self.num_orders * self.embedding_dim)
         p_flat = p.reshape(p.size(0), self.num_orders * self.embedding_dim)
-        return self.w(r_flat) + self.w_im(p_flat)
+        return self.w(r_flat) + self.w_im(p_flat)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/ffm.py

```diff
@@ -43,7 +43,8 @@ import torch
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import AveragePooling, InputMask, PredictionLayer, SumPooling
+from nextrec.basic.layers import AveragePooling, InputMask, SumPooling
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 from nextrec.utils.torch_utils import get_initializer
 
@@ -140,7 +141,7 @@ class FFM(BaseModel):
             nn.Linear(dense_input_dim, 1, bias=True) if dense_input_dim > 0 else None
         )
 
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
         self.input_mask = InputMask()
         self.mean_pool = AveragePooling()
         self.sum_pool = SumPooling()
@@ -272,4 +273,4 @@ class FFM(BaseModel):
         )
 
         y = y_linear + y_interaction
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/fibinet.py

```diff
@@ -50,9 +50,9 @@ from nextrec.basic.layers import (
     BiLinearInteractionLayer,
     EmbeddingLayer,
     HadamardInteractionLayer,
-    PredictionLayer,
     SENETLayer,
 )
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -168,7 +168,7 @@ class FiBiNET(BaseModel):
         num_pairs = self.num_fields * (self.num_fields - 1) // 2
         interaction_dim = num_pairs * self.embedding_dim * 2
         self.mlp = MLP(input_dim=interaction_dim, **mlp_params)
-        self.prediction_layer = PredictionLayer(task_type=self.default_task)
+        self.prediction_layer = TaskHead(task_type=self.default_task)
 
         # Register regularization weights
         self.register_regularization_weights(
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/fm.py

```diff
@@ -42,7 +42,8 @@ import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import FM as FMInteraction
-from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
+from nextrec.basic.heads import TaskHead
+from nextrec.basic.layers import LR, EmbeddingLayer
 from nextrec.basic.model import BaseModel
 
 
@@ -105,7 +106,7 @@ class FM(BaseModel):
         fm_input_dim = sum([f.embedding_dim for f in self.fm_features])
         self.linear = LR(fm_input_dim)
         self.fm = FMInteraction(reduce_sum=True)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         # Register regularization weights
         self.register_regularization_weights(
@@ -124,4 +125,4 @@ class FM(BaseModel):
         y_linear = self.linear(input_fm.flatten(start_dim=1))
         y_fm = self.fm(input_fm)
         y = y_linear + y_fm
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```
{nextrec-0.4.16 → nextrec-0.4.17}/nextrec/models/ranking/lr.py

```diff
@@ -41,7 +41,8 @@ LR 是 CTR/排序任务中最经典的线性基线模型。它将稠密、稀疏
 import torch.nn as nn
 
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import EmbeddingLayer, LR as LinearLayer, PredictionLayer
+from nextrec.basic.layers import EmbeddingLayer, LR as LinearLayer
+from nextrec.basic.heads import TaskHead
 from nextrec.basic.model import BaseModel
 
 
@@ -99,7 +100,7 @@ class LR(BaseModel):
         self.embedding = EmbeddingLayer(features=self.all_features)
         linear_input_dim = self.embedding.input_dim
         self.linear = LinearLayer(linear_input_dim)
-        self.prediction_layer = PredictionLayer(task_type=self.task)
+        self.prediction_layer = TaskHead(task_type=self.task)
 
         self.register_regularization_weights(
             embedding_attr="embedding", include_modules=["linear"]
@@ -115,4 +116,4 @@ class LR(BaseModel):
     def forward(self, x):
         input_linear = self.embedding(x=x, features=self.all_features, squeeze_dim=True)
         y = self.linear(input_linear)
-        return self.prediction_layer(y)
+        return self.prediction_layer(y)
```