nextrec 0.4.33__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.4.33 → nextrec-0.5.0}/.gitignore +3 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/PKG-INFO +10 -4
- {nextrec-0.4.33 → nextrec-0.5.0}/README.md +5 -3
- {nextrec-0.4.33 → nextrec-0.5.0}/README_en.md +5 -3
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/conf.py +1 -1
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.utils.rst +0 -8
- nextrec-0.5.0/nextrec/__version__.py +1 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/activation.py +10 -18
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/asserts.py +1 -22
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/callback.py +2 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/features.py +6 -37
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/heads.py +13 -1
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/layers.py +33 -123
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/loggers.py +3 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/metrics.py +85 -4
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/model.py +518 -7
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/summary.py +88 -42
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/cli.py +117 -30
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/data_processing.py +8 -13
- nextrec-0.5.0/nextrec/data/preprocessor.py +1196 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/grad_norm.py +78 -76
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/ple.py +1 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/share_bottom.py +1 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/afm.py +4 -9
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/dien.py +7 -8
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/ffm.py +2 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/sdm.py +1 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/sequential/hstu.py +0 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/base.py +1 -1
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/__init__.py +2 -1
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/config.py +1 -1
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/console.py +1 -1
- nextrec-0.5.0/nextrec/utils/onnx_utils.py +252 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/torch_utils.py +63 -56
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/types.py +43 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/NextRec-CLI.md +0 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/NextRec-CLI_zh.md +0 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/predict_config.yaml +6 -3
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/predict_config_template.yaml +6 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/train_config.yaml +5 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/train_config_template.yaml +6 -2
- {nextrec-0.4.33 → nextrec-0.5.0}/pyproject.toml +5 -1
- {nextrec-0.4.33 → nextrec-0.5.0}/requirements.txt +4 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/run_tests.py +5 -1
- nextrec-0.5.0/test/test_onnx_models.py +620 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_preprocessor.py +8 -6
- nextrec-0.5.0/tutorials/distributed/example_distributed_training.py +234 -0
- nextrec-0.5.0/tutorials/distributed/example_distributed_training_large_dataset.py +251 -0
- nextrec-0.5.0/tutorials/example_match.py +261 -0
- nextrec-0.5.0/tutorials/example_multitask.py +221 -0
- nextrec-0.5.0/tutorials/example_onnx.py +300 -0
- nextrec-0.5.0/tutorials/example_ranking_din.py +217 -0
- nextrec-0.5.0/tutorials/example_tree.py +205 -0
- nextrec-0.5.0/tutorials/movielen_match_dssm.py +270 -0
- nextrec-0.5.0/tutorials/movielen_ranking_deepfm.py +183 -0
- nextrec-0.5.0/tutorials/run_all_match_models.py +303 -0
- nextrec-0.5.0/tutorials/run_all_multitask_models.py +396 -0
- nextrec-0.5.0/tutorials/run_all_ranking_models.py +388 -0
- nextrec-0.4.33/nextrec/__version__.py +0 -1
- nextrec-0.4.33/nextrec/data/preprocessor.py +0 -1591
- nextrec-0.4.33/nextrec/models/multi_task/[pre]star.py +0 -192
- nextrec-0.4.33/nextrec/models/representation/autorec.py +0 -0
- nextrec-0.4.33/nextrec/models/representation/bpr.py +0 -0
- nextrec-0.4.33/nextrec/models/representation/cl4srec.py +0 -0
- nextrec-0.4.33/nextrec/models/representation/lightgcn.py +0 -0
- nextrec-0.4.33/nextrec/models/representation/mf.py +0 -0
- nextrec-0.4.33/nextrec/models/representation/s3rec.py +0 -0
- nextrec-0.4.33/nextrec/models/sequential/sasrec.py +0 -0
- nextrec-0.4.33/nextrec/utils/feature.py +0 -29
- nextrec-0.4.33/tutorials/distributed/example_distributed_training.py +0 -158
- nextrec-0.4.33/tutorials/distributed/example_distributed_training_large_dataset.py +0 -158
- nextrec-0.4.33/tutorials/example_match.py +0 -164
- nextrec-0.4.33/tutorials/example_multitask.py +0 -122
- nextrec-0.4.33/tutorials/example_ranking_din.py +0 -125
- nextrec-0.4.33/tutorials/example_tree.py +0 -97
- nextrec-0.4.33/tutorials/movielen_match_dssm.py +0 -155
- nextrec-0.4.33/tutorials/movielen_ranking_deepfm.py +0 -73
- nextrec-0.4.33/tutorials/run_all_match_models.py +0 -210
- nextrec-0.4.33/tutorials/run_all_multitask_models.py +0 -285
- nextrec-0.4.33/tutorials/run_all_ranking_models.py +0 -264
- {nextrec-0.4.33 → nextrec-0.5.0}/.github/workflows/publish.yml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/.github/workflows/tests.yml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/.readthedocs.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/CONTRIBUTING.md +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/LICENSE +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/MANIFEST.in +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/Feature Configuration.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/Model Parameters.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/Training Configuration.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/Training logs.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/logo.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/mmoe_tutorial.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/nextrec_diagram.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/assets/test data.png +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/dataset/ctcvr_task.csv +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/dataset/ecommerce_task.csv +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/dataset/match_task.csv +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/dataset/multitask_task.csv +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/dataset/ranking_task.csv +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/en/Getting started guide.md +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/Makefile +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/index.md +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/make.bat +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/modules.rst +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.basic.rst +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.data.rst +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.loss.rst +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/nextrec.rst +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/rtd/requirements.txt +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/docs/zh/快速上手.md +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/basic/session.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/batch_utils.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/data_utils.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/data/dataloader.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/listwise.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/pairwise.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/loss/pointwise.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/generative/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/[pre]aitm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/[pre]snr_trans.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/apg.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/cross_stitch.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/escm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/esmm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/hmoe.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/mmoe.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/pepnet.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/multi_task/poso.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/autoint.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/dcn.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/dcn_v2.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/deepfm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/din.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/eulernet.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/fibinet.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/fm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/lr.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/masknet.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/pnn.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/widedeep.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/ranking/xdeepfm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/representation/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/representation/rqvae.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/dssm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/dssm_v2.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/mind.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/retrieval/youtube_dnn.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/catboost.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/lightgbm.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/models/tree_base/xgboost.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/data.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/loss.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec/utils/model.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/feature_config.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/afm.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/apg.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/autoint.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/cross_stitch.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/dcn.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/deepfm.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/din.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/escm.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/esmm.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/fibinet.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/fm.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/hmoe.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/masknet.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/mmoe.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/pepnet.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/ple.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/pnn.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/poso.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/share_bottom.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/widedeep.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/nextrec_cli_preset/model_configs/xdeepfm.yaml +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/pytest.ini +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/scripts/format_code.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/__init__.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/conftest.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/helpers.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_base_model_regularization.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_generative_models.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_layers.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_losses.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_match_models.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_multitask_models.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_ranking_models.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_utils_console.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_utils_data.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test/test_utils_embedding.py +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/test_requirements.txt +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/en/Build semantic ID with RQ-VAE.ipynb +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/zh/使用RQ-VAE构建语义ID.ipynb +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/zh/如何使用DataProcessor进行预处理.ipynb +0 -0
- {nextrec-0.4.33 → nextrec-0.5.0}/tutorials/notebooks/zh/快速入门nextrec.ipynb +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: nextrec
|
|
3
|
-
Version: 0.4.33
|
|
3
|
+
Version: 0.5.0
|
|
4
4
|
Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
|
|
5
5
|
Project-URL: Homepage, https://github.com/zerolovesea/NextRec
|
|
6
6
|
Project-URL: Repository, https://github.com/zerolovesea/NextRec
|
|
@@ -24,10 +24,14 @@ Requires-Dist: numpy<2.0,>=1.21; sys_platform == 'linux' and python_version < '3
|
|
|
24
24
|
Requires-Dist: numpy<3.0,>=1.26; sys_platform == 'linux' and python_version >= '3.12'
|
|
25
25
|
Requires-Dist: numpy>=1.23.0; sys_platform == 'win32'
|
|
26
26
|
Requires-Dist: numpy>=1.24.0; sys_platform == 'darwin'
|
|
27
|
+
Requires-Dist: onnx>=1.16.0
|
|
28
|
+
Requires-Dist: onnxruntime>=1.18.0
|
|
29
|
+
Requires-Dist: onnxscript>=0.1.1
|
|
27
30
|
Requires-Dist: pandas<2.0,>=1.5; sys_platform == 'linux' and python_version < '3.12'
|
|
28
31
|
Requires-Dist: pandas<2.3.0,>=2.1.0; sys_platform == 'win32'
|
|
29
32
|
Requires-Dist: pandas>=2.0.0; sys_platform == 'darwin'
|
|
30
33
|
Requires-Dist: pandas>=2.1.0; sys_platform == 'linux' and python_version >= '3.12'
|
|
34
|
+
Requires-Dist: polars>=0.20.0
|
|
31
35
|
Requires-Dist: pyarrow<13.0.0,>=10.0.0; sys_platform == 'linux' and python_version < '3.12'
|
|
32
36
|
Requires-Dist: pyarrow<15.0.0,>=12.0.0; sys_platform == 'win32'
|
|
33
37
|
Requires-Dist: pyarrow>=12.0.0; sys_platform == 'darwin'
|
|
@@ -69,7 +73,7 @@ Description-Content-Type: text/markdown
|
|
|
69
73
|

|
|
70
74
|

|
|
71
75
|

|
|
72
|
-

|
|
73
77
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
74
78
|
|
|
75
79
|
中文文档 | [English Version](README_en.md)
|
|
@@ -102,6 +106,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
|
|
|
102
106
|
- **高效训练与评估**:内置多种优化器、学习率调度、早停、模型检查点与详细的日志管理,开箱即用。
|
|
103
107
|
|
|
104
108
|
## NextRec近期进展
|
|
109
|
+
- **28/01/2026** 在v0.4.39中加入了对onnx导出和加载的支持,并大幅提升了数据预处理速度(最高加速9倍)
|
|
105
110
|
- **01/01/2026** 新年好,在v0.4.27中加入了多个多目标模型的支持:[APG](nextrec/models/multi_task/apg.py), [ESCM](nextrec/models/multi_task/escm.py), [HMoE](nextrec/models/multi_task/hmoe.py), [Cross Stitch](nextrec/models/multi_task/cross_stitch.py)
|
|
106
111
|
- **28/12/2025** 在v0.4.21中加入了对SwanLab和Wandb的支持,通过model的`fit`方法进行配置:`use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
|
|
107
112
|
- **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
|
|
@@ -136,6 +141,7 @@ pip install nextrec # or pip install -e .
|
|
|
136
141
|
- [example_multitask.py](/tutorials/example_multitask.py) - 电商数据集上的ESMM多任务学习训练示例
|
|
137
142
|
- [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) - 基于movielen 100k数据集训练的 DSSM 召回模型示例
|
|
138
143
|
|
|
144
|
+
- [example_onnx.py](/tutorials/example_onnx.py) - 使用NextRec训练和导出onnx模型
|
|
139
145
|
- [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) - 使用NextRec进行单机多卡训练的代码示例
|
|
140
146
|
|
|
141
147
|
- [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) - 快速校验所有排序模型的可用性
|
|
@@ -254,11 +260,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
254
260
|
|
|
255
261
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
256
262
|
|
|
257
|
-
> 截止当前版本0.
|
|
263
|
+
> 截止当前版本0.5.0,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
258
264
|
|
|
259
265
|
## 兼容平台
|
|
260
266
|
|
|
261
|
-
当前最新版本为0.
|
|
267
|
+
当前最新版本为0.5.0,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
262
268
|
|
|
263
269
|
| 平台 | 配置 |
|
|
264
270
|
|------|------|
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|

|
|
9
9
|

|
|
10
10
|

|
|
11
|
-

|
|
12
12
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
13
13
|
|
|
14
14
|
中文文档 | [English Version](README_en.md)
|
|
@@ -41,6 +41,7 @@ NextRec是一个基于PyTorch的现代推荐系统框架,旨在为研究工程
|
|
|
41
41
|
- **高效训练与评估**:内置多种优化器、学习率调度、早停、模型检查点与详细的日志管理,开箱即用。
|
|
42
42
|
|
|
43
43
|
## NextRec近期进展
|
|
44
|
+
- **28/01/2026** 在v0.4.39中加入了对onnx导出和加载的支持,并大幅提升了数据预处理速度(最高加速9倍)
|
|
44
45
|
- **01/01/2026** 新年好,在v0.4.27中加入了多个多目标模型的支持:[APG](nextrec/models/multi_task/apg.py), [ESCM](nextrec/models/multi_task/escm.py), [HMoE](nextrec/models/multi_task/hmoe.py), [Cross Stitch](nextrec/models/multi_task/cross_stitch.py)
|
|
45
46
|
- **28/12/2025** 在v0.4.21中加入了对SwanLab和Wandb的支持,通过model的`fit`方法进行配置:`use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
|
|
46
47
|
- **21/12/2025** 在v0.4.16中加入了对[GradNorm](/nextrec/loss/grad_norm.py)的支持,通过compile的`loss_weight='grad_norm'`进行配置
|
|
@@ -75,6 +76,7 @@ pip install nextrec # or pip install -e .
|
|
|
75
76
|
- [example_multitask.py](/tutorials/example_multitask.py) - 电商数据集上的ESMM多任务学习训练示例
|
|
76
77
|
- [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) - 基于movielen 100k数据集训练的 DSSM 召回模型示例
|
|
77
78
|
|
|
79
|
+
- [example_onnx.py](/tutorials/example_onnx.py) - 使用NextRec训练和导出onnx模型
|
|
78
80
|
- [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) - 使用NextRec进行单机多卡训练的代码示例
|
|
79
81
|
|
|
80
82
|
- [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) - 快速校验所有排序模型的可用性
|
|
@@ -193,11 +195,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
193
195
|
|
|
194
196
|
预测结果固定保存到 `{checkpoint_path}/predictions/{name}.{save_data_format}`。
|
|
195
197
|
|
|
196
|
-
> 截止当前版本0.
|
|
198
|
+
> 截止当前版本0.5.0,NextRec CLI支持单机训练,分布式训练相关功能尚在开发中。
|
|
197
199
|
|
|
198
200
|
## 兼容平台
|
|
199
201
|
|
|
200
|
-
当前最新版本为0.
|
|
202
|
+
当前最新版本为0.5.0,所有模型和测试代码均已在以下平台通过验证,如果开发者在使用中遇到兼容问题,请在issue区提出错误报告及系统版本:
|
|
201
203
|
|
|
202
204
|
| 平台 | 配置 |
|
|
203
205
|
|------|------|
|
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|

|
|
9
9
|

|
|
10
10
|

|
|
11
|
-

|
|
12
12
|
[](https://deepwiki.com/zerolovesea/NextRec)
|
|
13
13
|
|
|
14
14
|
English | [中文文档](README.md)
|
|
@@ -44,6 +44,7 @@ NextRec is a modern recommendation framework built on PyTorch, delivering a unif
|
|
|
44
44
|
|
|
45
45
|
## NextRec Progress
|
|
46
46
|
|
|
47
|
+
- **28/01/2026** Added support for ONNX export and loading in v0.4.39, and significantly sped up data preprocessing (up to 9x faster)
|
|
47
48
|
- **01/01/2026** Happy New Year! In v0.4.27, added support for multiple multi-task models: [APG](/nextrec/models/multi_task/apg.py), [ESCM](/nextrec/models/multi_task/escm.py), [HMoE](/nextrec/models/multi_task/hmoe.py), [Cross Stitch](/nextrec/models/multi_task/cross_stitch.py)
|
|
48
49
|
- **28/12/2025** Added support for SwanLab and Weights & Biases in v0.4.21, configurable via the model `fit` method: `use_swanlab=True, swanlab_kwargs={"project": "NextRec","name":"tutorial_movielens_deepfm"},`
|
|
49
50
|
- **21/12/2025** Added support for [GradNorm](/nextrec/loss/grad_norm.py) in v0.4.16, configurable via `loss_weight='grad_norm'` in the compile method
|
|
@@ -79,6 +80,7 @@ See `tutorials/` for examples covering ranking, retrieval, multi-task learning,
|
|
|
79
80
|
- [example_multitask.py](/tutorials/example_multitask.py) — ESMM multi-task learning training on e-commerce dataset
|
|
80
81
|
- [movielen_match_dssm.py](/tutorials/movielen_match_dssm.py) — DSSM retrieval model training on MovieLens 100k dataset
|
|
81
82
|
|
|
83
|
+
- [example_onnx.py](/tutorials/example_onnx.py) — Train and export models to ONNX format with NextRec
|
|
82
84
|
- [example_distributed_training.py](/tutorials/distributed/example_distributed_training.py) — Single-machine multi-GPU training with NextRec
|
|
83
85
|
|
|
84
86
|
- [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) — Quickly validate availability of all ranking models
|
|
@@ -196,11 +198,11 @@ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
|
|
|
196
198
|
|
|
197
199
|
Prediction outputs are saved under `{checkpoint_path}/predictions/{name}.{save_data_format}`.
|
|
198
200
|
|
|
199
|
-
> As of version 0.
|
|
201
|
+
> As of version 0.5.0, NextRec CLI supports single-machine training; distributed training features are currently under development.
|
|
200
202
|
|
|
201
203
|
## Platform Compatibility
|
|
202
204
|
|
|
203
|
-
The current version is 0.
|
|
205
|
+
The current version is 0.5.0. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
|
|
204
206
|
|
|
205
207
|
| Platform | Configuration |
|
|
206
208
|
|----------|---------------|
|
|
@@ -28,14 +28,6 @@ nextrec.utils.data module
|
|
|
28
28
|
:undoc-members:
|
|
29
29
|
:show-inheritance:
|
|
30
30
|
|
|
31
|
-
nextrec.utils.feature module
|
|
32
|
-
----------------------------
|
|
33
|
-
|
|
34
|
-
.. automodule:: nextrec.utils.feature
|
|
35
|
-
:members:
|
|
36
|
-
:undoc-members:
|
|
37
|
-
:show-inheritance:
|
|
38
|
-
|
|
39
31
|
nextrec.utils.model module
|
|
40
32
|
--------------------------
|
|
41
33
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.5.0"
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
"""
|
|
2
|
-
Activation function definitions
|
|
2
|
+
Activation function definitions.
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 20/01/2026
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -22,26 +22,18 @@ class Dice(nn.Module):
|
|
|
22
22
|
where p(x) = sigmoid((x - E[x]) / sqrt(Var[x] + epsilon))
|
|
23
23
|
"""
|
|
24
24
|
|
|
25
|
-
def __init__(self, emb_size: int, epsilon: float = 1e-
|
|
25
|
+
def __init__(self, emb_size: int, epsilon: float = 1e-3):
|
|
26
26
|
super(Dice, self).__init__()
|
|
27
|
-
self.epsilon = epsilon
|
|
28
27
|
self.alpha = nn.Parameter(torch.zeros(emb_size))
|
|
29
|
-
self.bn = nn.BatchNorm1d(emb_size)
|
|
28
|
+
self.bn = nn.BatchNorm1d(emb_size, eps=epsilon, affine=False)
|
|
30
29
|
|
|
31
30
|
def forward(self, x):
|
|
32
|
-
#
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
x = x.view(-1, emb_size)
|
|
39
|
-
x_norm = self.bn(x)
|
|
40
|
-
p = torch.sigmoid(x_norm)
|
|
41
|
-
output = p * x + (1 - p) * self.alpha * x
|
|
42
|
-
if len(original_shape) == 3:
|
|
43
|
-
output = output.view(original_shape)
|
|
44
|
-
return output
|
|
31
|
+
# keep original shape for reshaping back after batch norm
|
|
32
|
+
orig_shape = x.shape # x: [N, L, emb_size] or [N, emb_size]
|
|
33
|
+
x2 = x.reshape(-1, orig_shape[-1]) # x2:[N*L, emb_size] or [N, emb_size]
|
|
34
|
+
x_norm = self.bn(x2)
|
|
35
|
+
p = torch.sigmoid(x_norm).reshape(orig_shape)
|
|
36
|
+
return x * (self.alpha + (1 - self.alpha) * p)
|
|
45
37
|
|
|
46
38
|
|
|
47
39
|
def activation_layer(
|
|
@@ -8,7 +8,7 @@ Author: Yang Zhou, zyaztec@gmail.com
|
|
|
8
8
|
|
|
9
9
|
from __future__ import annotations
|
|
10
10
|
|
|
11
|
-
from nextrec.utils.types import TaskTypeName
|
|
11
|
+
from nextrec.utils.types import TaskTypeName
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
def assert_task(
|
|
@@ -49,24 +49,3 @@ def assert_task(
|
|
|
49
49
|
raise ValueError(
|
|
50
50
|
f"{model_name} requires task length {nums_task}, got {len(task)}."
|
|
51
51
|
)
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
def assert_training_mode(
|
|
55
|
-
training_mode: TrainingModeName | list[TrainingModeName],
|
|
56
|
-
nums_task: int,
|
|
57
|
-
*,
|
|
58
|
-
model_name: str,
|
|
59
|
-
) -> None:
|
|
60
|
-
valid_modes = {"pointwise", "pairwise", "listwise"}
|
|
61
|
-
if not isinstance(training_mode, list):
|
|
62
|
-
raise TypeError(
|
|
63
|
-
f"[{model_name}-init Error] training_mode must be a list with length {nums_task}."
|
|
64
|
-
)
|
|
65
|
-
if len(training_mode) != nums_task:
|
|
66
|
-
raise ValueError(
|
|
67
|
-
f"[{model_name}-init Error] training_mode list length must match number of tasks."
|
|
68
|
-
)
|
|
69
|
-
if any(mode not in valid_modes for mode in training_mode):
|
|
70
|
-
raise ValueError(
|
|
71
|
-
f"[{model_name}-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
|
|
72
|
-
)
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Callback System for Training Process
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 21/01/2026
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -69,7 +69,7 @@ class Callback:
|
|
|
69
69
|
|
|
70
70
|
class CallbackList:
|
|
71
71
|
"""
|
|
72
|
-
Generates a list of callbacks
|
|
72
|
+
Generates a list of callbacks, used to manage and invoke multiple callbacks during training.
|
|
73
73
|
"""
|
|
74
74
|
|
|
75
75
|
def __init__(self, callbacks: Optional[list[Callback]] = None):
|
|
@@ -8,10 +8,9 @@ Author: Yang Zhou, zyaztec@gmail.com
|
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
10
|
|
|
11
|
-
from typing import Literal
|
|
12
|
-
|
|
13
11
|
from nextrec.utils.embedding import get_auto_embedding_dim
|
|
14
|
-
from nextrec.utils.
|
|
12
|
+
from nextrec.utils.torch_utils import to_list
|
|
13
|
+
from nextrec.utils.types import EmbeddingInitType, SequenceCombinerType
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
class BaseFeature:
|
|
@@ -29,15 +28,7 @@ class EmbeddingFeature(BaseFeature):
|
|
|
29
28
|
embedding_name: str = "",
|
|
30
29
|
embedding_dim: int | None = None,
|
|
31
30
|
padding_idx: int = 0,
|
|
32
|
-
init_type:
|
|
33
|
-
"normal",
|
|
34
|
-
"uniform",
|
|
35
|
-
"xavier_uniform",
|
|
36
|
-
"xavier_normal",
|
|
37
|
-
"kaiming_uniform",
|
|
38
|
-
"kaiming_normal",
|
|
39
|
-
"orthogonal",
|
|
40
|
-
] = "normal",
|
|
31
|
+
init_type: EmbeddingInitType = "normal",
|
|
41
32
|
init_params: dict | None = None,
|
|
42
33
|
l1_reg: float = 0.0,
|
|
43
34
|
l2_reg: float = 0.0,
|
|
@@ -73,23 +64,9 @@ class SequenceFeature(EmbeddingFeature):
|
|
|
73
64
|
max_len: int = 50,
|
|
74
65
|
embedding_name: str = "",
|
|
75
66
|
embedding_dim: int | None = None,
|
|
76
|
-
combiner:
|
|
77
|
-
"mean",
|
|
78
|
-
"sum",
|
|
79
|
-
"concat",
|
|
80
|
-
"dot_attention",
|
|
81
|
-
"self_attention",
|
|
82
|
-
] = "mean",
|
|
67
|
+
combiner: SequenceCombinerType = "mean",
|
|
83
68
|
padding_idx: int = 0,
|
|
84
|
-
init_type:
|
|
85
|
-
"normal",
|
|
86
|
-
"uniform",
|
|
87
|
-
"xavier_uniform",
|
|
88
|
-
"xavier_normal",
|
|
89
|
-
"kaiming_uniform",
|
|
90
|
-
"kaiming_normal",
|
|
91
|
-
"orthogonal",
|
|
92
|
-
] = "normal",
|
|
69
|
+
init_type: EmbeddingInitType = "normal",
|
|
93
70
|
init_params: dict | None = None,
|
|
94
71
|
l1_reg: float = 0.0,
|
|
95
72
|
l2_reg: float = 0.0,
|
|
@@ -143,15 +120,7 @@ class SparseFeature(EmbeddingFeature):
|
|
|
143
120
|
embedding_name: str = "",
|
|
144
121
|
embedding_dim: int | None = None,
|
|
145
122
|
padding_idx: int = 0,
|
|
146
|
-
init_type:
|
|
147
|
-
"normal",
|
|
148
|
-
"uniform",
|
|
149
|
-
"xavier_uniform",
|
|
150
|
-
"xavier_normal",
|
|
151
|
-
"kaiming_uniform",
|
|
152
|
-
"kaiming_normal",
|
|
153
|
-
"orthogonal",
|
|
154
|
-
] = "normal",
|
|
123
|
+
init_type: EmbeddingInitType = "normal",
|
|
155
124
|
init_params: dict | None = None,
|
|
156
125
|
l1_reg: float = 0.0,
|
|
157
126
|
l2_reg: float = 0.0,
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Task head implementations for NextRec models.
|
|
3
3
|
|
|
4
4
|
Date: create on 23/12/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 22/01/2026
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -24,6 +24,12 @@ class TaskHead(nn.Module):
|
|
|
24
24
|
|
|
25
25
|
This wraps PredictionLayer so models can depend on a "Head" abstraction
|
|
26
26
|
without changing their existing forward signatures.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
task_type: The type of task(s) this head is responsible for.
|
|
30
|
+
task_dims: The dimensionality of each task's output.
|
|
31
|
+
use_bias: Whether to include a bias term in the prediction layer.
|
|
32
|
+
return_logits: Whether to return raw logits or apply activation.
|
|
27
33
|
"""
|
|
28
34
|
|
|
29
35
|
def __init__(
|
|
@@ -56,6 +62,12 @@ class RetrievalHead(nn.Module):
|
|
|
56
62
|
|
|
57
63
|
It computes similarity for pointwise training/inference, and returns
|
|
58
64
|
raw embeddings for in-batch negative sampling in pairwise/listwise modes.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
similarity_metric: The metric used to compute similarity between embeddings.
|
|
68
|
+
temperature: Scaling factor for similarity scores.
|
|
69
|
+
training_mode: The training mode, which can be pointwise, pairwise, or listwise.
|
|
70
|
+
apply_sigmoid: Whether to apply sigmoid activation to the similarity scores in pointwise mode.
|
|
59
71
|
"""
|
|
60
72
|
|
|
61
73
|
def __init__(
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
Layer implementations used across NextRec.
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 25/01/2026
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -20,15 +20,13 @@ import torch.nn.functional as F
|
|
|
20
20
|
from nextrec.basic.activation import activation_layer
|
|
21
21
|
from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
|
|
22
22
|
from nextrec.utils.torch_utils import get_initializer
|
|
23
|
-
from nextrec.utils.types import ActivationName
|
|
23
|
+
from nextrec.utils.types import ActivationName, TaskTypeName
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class PredictionLayer(nn.Module):
|
|
27
27
|
def __init__(
|
|
28
28
|
self,
|
|
29
|
-
task_type:
|
|
30
|
-
Literal["binary", "regression"] | list[Literal["binary", "regression"]]
|
|
31
|
-
) = "binary",
|
|
29
|
+
task_type: TaskTypeName | list[TaskTypeName] = "binary",
|
|
32
30
|
task_dims: int | list[int] | None = None,
|
|
33
31
|
use_bias: bool = True,
|
|
34
32
|
return_logits: bool = False,
|
|
@@ -81,10 +79,12 @@ class PredictionLayer(nn.Module):
|
|
|
81
79
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
82
80
|
if x.dim() == 1:
|
|
83
81
|
x = x.unsqueeze(0) # (1 * total_dim)
|
|
84
|
-
if
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
82
|
+
if not torch.onnx.is_in_onnx_export():
|
|
83
|
+
if x.shape[-1] != self.total_dim:
|
|
84
|
+
raise ValueError(
|
|
85
|
+
f"[PredictionLayer Error]: Input last dimension ({x.shape[-1]}) does not match expected total dimension ({self.total_dim})."
|
|
86
|
+
)
|
|
87
|
+
|
|
88
88
|
logits = x if self.bias is None else x + self.bias
|
|
89
89
|
outputs = []
|
|
90
90
|
for task_type, (start, end) in zip(self.task_types, self.task_slices):
|
|
@@ -92,10 +92,9 @@ class PredictionLayer(nn.Module):
|
|
|
92
92
|
if self.return_logits:
|
|
93
93
|
outputs.append(task_logits)
|
|
94
94
|
continue
|
|
95
|
-
|
|
96
|
-
if task == "binary":
|
|
95
|
+
if task_type == "binary":
|
|
97
96
|
outputs.append(torch.sigmoid(task_logits))
|
|
98
|
-
elif
|
|
97
|
+
elif task_type == "regression":
|
|
99
98
|
outputs.append(task_logits)
|
|
100
99
|
else:
|
|
101
100
|
raise ValueError(
|
|
@@ -219,7 +218,7 @@ class EmbeddingLayer(nn.Module):
|
|
|
219
218
|
|
|
220
219
|
elif isinstance(feature, SequenceFeature):
|
|
221
220
|
seq_input = x[feature.name].long()
|
|
222
|
-
if feature.max_len is not None
|
|
221
|
+
if feature.max_len is not None:
|
|
223
222
|
seq_input = seq_input[:, -feature.max_len :]
|
|
224
223
|
|
|
225
224
|
embed = self.embed_dict[feature.embedding_name]
|
|
@@ -282,10 +281,11 @@ class EmbeddingLayer(nn.Module):
|
|
|
282
281
|
value = value.view(value.size(0), -1) # [B, input_dim]
|
|
283
282
|
input_dim = feature.input_dim
|
|
284
283
|
assert_input_dim = self.dense_input_dims.get(feature.name, input_dim)
|
|
285
|
-
if
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
284
|
+
if not torch.onnx.is_in_onnx_export():
|
|
285
|
+
if value.shape[1] != assert_input_dim:
|
|
286
|
+
raise ValueError(
|
|
287
|
+
f"[EmbeddingLayer Error]:Dense feature '{feature.name}' expects {assert_input_dim} inputs but got {value.shape[1]}."
|
|
288
|
+
)
|
|
289
289
|
if not feature.use_projection:
|
|
290
290
|
return value
|
|
291
291
|
dense_layer = self.dense_transforms[feature.name]
|
|
@@ -331,29 +331,10 @@ class InputMask(nn.Module):
|
|
|
331
331
|
feature: SequenceFeature,
|
|
332
332
|
seq_tensor: torch.Tensor | None = None,
|
|
333
333
|
):
|
|
334
|
-
if seq_tensor is not None
|
|
335
|
-
|
|
336
|
-
else:
|
|
337
|
-
values = x[feature.name]
|
|
338
|
-
values = values.long()
|
|
334
|
+
values = seq_tensor if seq_tensor is not None else x[feature.name]
|
|
335
|
+
values = values.long().view(values.size(0), -1)
|
|
339
336
|
padding_idx = feature.padding_idx if feature.padding_idx is not None else 0
|
|
340
|
-
mask = values != padding_idx
|
|
341
|
-
|
|
342
|
-
if mask.dim() == 1:
|
|
343
|
-
# [B] -> [B, 1, 1]
|
|
344
|
-
mask = mask.unsqueeze(1).unsqueeze(2)
|
|
345
|
-
elif mask.dim() == 2:
|
|
346
|
-
# [B, L] -> [B, 1, L]
|
|
347
|
-
mask = mask.unsqueeze(1)
|
|
348
|
-
elif mask.dim() == 3:
|
|
349
|
-
# [B, 1, L]
|
|
350
|
-
# [B, L, 1] -> [B, L] -> [B, 1, L]
|
|
351
|
-
if mask.size(1) != 1 and mask.size(2) == 1:
|
|
352
|
-
mask = mask.squeeze(-1).unsqueeze(1)
|
|
353
|
-
else:
|
|
354
|
-
raise ValueError(
|
|
355
|
-
f"InputMask only supports 1D/2D/3D tensors, got shape {values.shape}"
|
|
356
|
-
)
|
|
337
|
+
mask = (values != padding_idx).unsqueeze(1)
|
|
357
338
|
return mask.float()
|
|
358
339
|
|
|
359
340
|
|
|
@@ -897,30 +878,7 @@ class AttentionPoolingLayer(nn.Module):
|
|
|
897
878
|
self,
|
|
898
879
|
embedding_dim: int,
|
|
899
880
|
hidden_units: list = [80, 40],
|
|
900
|
-
activation:
|
|
901
|
-
"dice",
|
|
902
|
-
"relu",
|
|
903
|
-
"relu6",
|
|
904
|
-
"elu",
|
|
905
|
-
"selu",
|
|
906
|
-
"leaky_relu",
|
|
907
|
-
"prelu",
|
|
908
|
-
"gelu",
|
|
909
|
-
"sigmoid",
|
|
910
|
-
"tanh",
|
|
911
|
-
"softplus",
|
|
912
|
-
"softsign",
|
|
913
|
-
"hardswish",
|
|
914
|
-
"mish",
|
|
915
|
-
"silu",
|
|
916
|
-
"swish",
|
|
917
|
-
"hardsigmoid",
|
|
918
|
-
"tanhshrink",
|
|
919
|
-
"softshrink",
|
|
920
|
-
"none",
|
|
921
|
-
"linear",
|
|
922
|
-
"identity",
|
|
923
|
-
] = "sigmoid",
|
|
881
|
+
activation: ActivationName = "sigmoid",
|
|
924
882
|
use_softmax: bool = False,
|
|
925
883
|
):
|
|
926
884
|
super().__init__()
|
|
@@ -954,39 +912,22 @@ class AttentionPoolingLayer(nn.Module):
|
|
|
954
912
|
output: [batch_size, embedding_dim] - attention pooled representation
|
|
955
913
|
"""
|
|
956
914
|
batch_size, sequence_length, embedding_dim = keys.shape
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
# keys_length: (batch_size,)
|
|
963
|
-
device = keys.device
|
|
964
|
-
seq_range = torch.arange(sequence_length, device=device).unsqueeze(
|
|
965
|
-
0
|
|
966
|
-
) # (1, sequence_length)
|
|
967
|
-
mask = (seq_range < keys_length.unsqueeze(1)).unsqueeze(-1).float()
|
|
968
|
-
if mask is not None:
|
|
969
|
-
if mask.dim() == 2:
|
|
970
|
-
# (B, L)
|
|
971
|
-
mask = mask.unsqueeze(-1)
|
|
972
|
-
elif (
|
|
973
|
-
mask.dim() == 3
|
|
974
|
-
and mask.shape[1] == 1
|
|
975
|
-
and mask.shape[2] == sequence_length
|
|
976
|
-
):
|
|
977
|
-
# (B, 1, L) -> (B, L, 1)
|
|
978
|
-
mask = mask.transpose(1, 2)
|
|
979
|
-
elif (
|
|
980
|
-
mask.dim() == 3
|
|
981
|
-
and mask.shape[1] == sequence_length
|
|
982
|
-
and mask.shape[2] == 1
|
|
983
|
-
):
|
|
984
|
-
pass
|
|
915
|
+
if mask is None:
|
|
916
|
+
if keys_length is None:
|
|
917
|
+
mask = torch.ones(
|
|
918
|
+
(batch_size, sequence_length), device=keys.device, dtype=keys.dtype
|
|
919
|
+
)
|
|
985
920
|
else:
|
|
921
|
+
device = keys.device
|
|
922
|
+
seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
|
|
923
|
+
mask = (seq_range < keys_length.unsqueeze(1)).to(keys.dtype)
|
|
924
|
+
else:
|
|
925
|
+
mask = mask.to(keys.dtype).reshape(batch_size, -1)
|
|
926
|
+
if mask.shape[1] != sequence_length:
|
|
986
927
|
raise ValueError(
|
|
987
928
|
f"[AttentionPoolingLayer Error]: Unsupported mask shape: {mask.shape}"
|
|
988
929
|
)
|
|
989
|
-
|
|
930
|
+
mask = mask.unsqueeze(-1)
|
|
990
931
|
# Expand query to (B, L, D)
|
|
991
932
|
query_expanded = query.unsqueeze(1).expand(-1, sequence_length, -1)
|
|
992
933
|
# [query, key, query-key, query*key] -> (B, L, 4D)
|
|
@@ -1026,34 +967,3 @@ class RMSNorm(torch.nn.Module):
|
|
|
1026
967
|
variance = torch.mean(x**2, dim=-1, keepdim=True)
|
|
1027
968
|
x_normalized = x * torch.rsqrt(variance + self.eps)
|
|
1028
969
|
return self.weight * x_normalized
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
class DomainBatchNorm(nn.Module):
|
|
1032
|
-
"""Domain-specific BatchNorm (applied per-domain with a shared interface)."""
|
|
1033
|
-
|
|
1034
|
-
def __init__(self, num_features: int, num_domains: int):
|
|
1035
|
-
super().__init__()
|
|
1036
|
-
if num_domains < 1:
|
|
1037
|
-
raise ValueError("num_domains must be >= 1")
|
|
1038
|
-
self.bns = nn.ModuleList(
|
|
1039
|
-
[nn.BatchNorm1d(num_features) for _ in range(num_domains)]
|
|
1040
|
-
)
|
|
1041
|
-
|
|
1042
|
-
def forward(self, x: torch.Tensor, domain_mask: torch.Tensor) -> torch.Tensor:
|
|
1043
|
-
if x.dim() != 2:
|
|
1044
|
-
raise ValueError("DomainBatchNorm expects 2D inputs [B, D].")
|
|
1045
|
-
output = x.clone()
|
|
1046
|
-
if domain_mask.dim() == 1:
|
|
1047
|
-
domain_ids = domain_mask.long()
|
|
1048
|
-
for idx, bn in enumerate(self.bns):
|
|
1049
|
-
mask = domain_ids == idx
|
|
1050
|
-
if mask.any():
|
|
1051
|
-
output[mask] = bn(x[mask])
|
|
1052
|
-
return output
|
|
1053
|
-
if domain_mask.dim() != 2:
|
|
1054
|
-
raise ValueError("domain_mask must be 1D indices or 2D one-hot mask.")
|
|
1055
|
-
for idx, bn in enumerate(self.bns):
|
|
1056
|
-
mask = domain_mask[:, idx] > 0
|
|
1057
|
-
if mask.any():
|
|
1058
|
-
output[mask] = bn(x[mask])
|
|
1059
|
-
return output
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
NextRec Basic Loggers
|
|
3
3
|
|
|
4
4
|
Date: create on 27/10/2025
|
|
5
|
-
Checkpoint: edit on
|
|
5
|
+
Checkpoint: edit on 22/01/2026
|
|
6
6
|
Author: Yang Zhou, zyaztec@gmail.com
|
|
7
7
|
"""
|
|
8
8
|
|
|
@@ -99,7 +99,8 @@ def format_kv(label: str, value: Any, width: int = 34, indent: int = 0) -> str:
|
|
|
99
99
|
|
|
100
100
|
|
|
101
101
|
def setup_logger(session_id: str | os.PathLike | None = None):
|
|
102
|
-
"""
|
|
102
|
+
"""
|
|
103
|
+
Set up a logger that logs to both console and a file with ANSI formatting.
|
|
103
104
|
Only console output has colors; file output is stripped of ANSI codes.
|
|
104
105
|
|
|
105
106
|
Logs are stored under ``log/<experiment_id>/logs`` by default. A stable
|