torch-rechub 0.0.6__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/workflows/ci.yml +42 -5
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/CHANGELOG.md +16 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/CONTRIBUTING.md +19 -12
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/PKG-INFO +8 -1
- torch_rechub-0.1.0/docs/en/core/data.md +86 -0
- torch_rechub-0.1.0/docs/en/core/evaluation.md +207 -0
- torch_rechub-0.1.0/docs/en/core/features.md +77 -0
- torch_rechub-0.1.0/docs/en/core/intro.md +30 -0
- torch_rechub-0.1.0/docs/en/models/generative.md +128 -0
- torch_rechub-0.1.0/docs/en/models/intro.md +102 -0
- torch_rechub-0.1.0/docs/en/models/matching.md +72 -0
- torch_rechub-0.1.0/docs/en/models/mtl.md +57 -0
- torch_rechub-0.1.0/docs/en/models/ranking.md +69 -0
- torch_rechub-0.1.0/docs/en/tools/tracking.md +112 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/core/evaluation.md +18 -1
- torch_rechub-0.1.0/docs/zh/serving/onnx.md +143 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/pyproject.toml +7 -3
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tests/test_parquet_dataset.py +17 -9
- torch_rechub-0.1.0/tests/test_serving.py +324 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/layers.py +213 -150
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/loss_func.py +62 -47
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/data/dataset.py +18 -31
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/generative/hstu.py +48 -33
- torch_rechub-0.1.0/torch_rechub/serving/__init__.py +50 -0
- torch_rechub-0.1.0/torch_rechub/serving/annoy.py +133 -0
- torch_rechub-0.1.0/torch_rechub/serving/base.py +107 -0
- torch_rechub-0.1.0/torch_rechub/serving/faiss.py +154 -0
- torch_rechub-0.1.0/torch_rechub/serving/milvus.py +215 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/trainers/ctr_trainer.py +12 -2
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/trainers/match_trainer.py +13 -2
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/trainers/mtl_trainer.py +12 -2
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/trainers/seq_trainer.py +34 -15
- torch_rechub-0.1.0/torch_rechub/types.py +5 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/data.py +167 -137
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/hstu_utils.py +87 -76
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/model_utils.py +10 -12
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/onnx_export.py +98 -45
- torch_rechub-0.1.0/torch_rechub/utils/quantization.py +128 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/visualization.py +4 -12
- torch_rechub-0.1.0/tutorials/00_QuickStart_CTR_DeepFM.ipynb +300 -0
- torch_rechub-0.1.0/tutorials/01_Ranking_DIN.ipynb +314 -0
- torch_rechub-0.1.0/tutorials/02_Matching_DSSM.ipynb +394 -0
- torch_rechub-0.1.0/tutorials/03_MultiTask_MMOE.ipynb +228 -0
- torch_rechub-0.1.0/tutorials/04_Experiment_Tracking_Light.ipynb +297 -0
- torch_rechub-0.1.0/tutorials/05_Model_Export_and_Serving.ipynb +438 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/uv.lock +951 -65
- torch_rechub-0.0.6/docs/en/core/data.md +0 -0
- torch_rechub-0.0.6/docs/en/core/evaluation.md +0 -0
- torch_rechub-0.0.6/docs/en/core/features.md +0 -0
- torch_rechub-0.0.6/docs/en/models/matching.md +0 -0
- torch_rechub-0.0.6/docs/en/models/mtl.md +0 -0
- torch_rechub-0.0.6/docs/en/models/ranking.md +0 -0
- torch_rechub-0.0.6/docs/en/tools/intro.md +0 -0
- torch_rechub-0.0.6/docs/en/tools/tracking.md +0 -0
- torch_rechub-0.0.6/docs/en/tutorials/intro.md +0 -0
- torch_rechub-0.0.6/docs/zh/serving/onnx.md +0 -9
- torch_rechub-0.0.6/torch_rechub/trainers/matching.md +0 -3
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/ISSUE_TEMPLATE/help_wanted.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/dependabot.yml +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/pull_request_template.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/release.yml +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.github/workflows/deploy.yml +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.gitignore +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/.pre-commit-config.yaml +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/CODE_OF_CONDUCT.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/LICENSE +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/README.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/README_en.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/config/.flake8 +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/config/.pep8 +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/config/.pre-commit-config.yaml +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/config/CONFIG_GUIDE.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/config/fix_encoding.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/config/format_code.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/config/pytest.ini +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/.vitepress/config.mts +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/.vitepress/theme/custom.css +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/.vitepress/theme/index.ts +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/.vitepress/theme/style.css +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/api-basic.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/api-models.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/api-trainers.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/api-utils.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/hllm_reproduction.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/hstu_reproduction.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/match.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/rank.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/cache/参考资料.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/api/api.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/blog/hllm_reproduction.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/blog/match.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/blog/rank.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/community/changelog.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/community/contributing.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/community/faq.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/contributing.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/guide/install.md +0 -0
- {torch_rechub-0.0.6/docs/en/core → torch_rechub-0.1.0/docs/en/guide}/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/guide/quick_start.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/index.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/introduction.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/api-reference/basic.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/api-reference/models.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/api-reference/trainers.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/api-reference/utils.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/faq.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/getting-started.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/installation.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/tutorials/matching.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/tutorials/multi-task.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/manual/tutorials/ranking.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/serving/demo.md +0 -0
- {torch_rechub-0.0.6/docs/en/guide → torch_rechub-0.1.0/docs/en/serving}/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/serving/onnx.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/serving/vector_index.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/tools/callbacks.md +0 -0
- {torch_rechub-0.0.6/docs/en/models → torch_rechub-0.1.0/docs/en/tools}/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/tools/visualization.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/tutorials/ctr.md +0 -0
- {torch_rechub-0.0.6/docs/en/serving → torch_rechub-0.1.0/docs/en/tutorials}/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/tutorials/pipeline.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/en/tutorials/retrieval.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/favicon.ico +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/img/banner.png +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/img/logo.png +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/img/logo_with_name.png +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/img/project_framework.jpg +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/img/win_install_annoy_error.png +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1606.07792_l8JrVnuYXA.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1703.04247_sFSyE7q3U1.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1706.06978_0xZD_K10S2.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1708.05123_f3lKSqxIvw.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1711.00165_eosOSOmTfE.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1804.07931_ybf_jOAFRp.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1808.09781-3_bmRm284Rxd.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1808.09781v1.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/1905.06336_2oH3RMtROA.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/2006.11632_qiN67CrHNs.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/2020 (Tencent) (Recsys) [PLE] Progressive Layered .pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/2102.09267_cdwBFKPCrj.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/2105.08489-2_XnVVGxN9GG.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/2203.06801v1-3_qUTY4TbvSL.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/2959100.2959190_jRzTU81Xmq.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/3219819.3219950_aTMFXHL3JB.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/3219819.3220007_zvaZg_CZ6z.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/4545-Article Text-7584-1-10-20190706.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/6c8a86c981a62b0126a11896b7f6ae0dae4c3566_1QYYhqJR8.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/Caruana1997_Article_MultitaskLearning_ySprcjzJ6v.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/DCN V2 Improved Deep & Cross Network and Practical.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/public/pdf/cikm2013_DSSM_fullversion_c9ZSdM19XJ.pdf +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/api/api.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/community/changelog.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/community/contributing.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/community/faq.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/core/data.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/core/features.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/core/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/guide/install.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/guide/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/guide/quick_start.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/index.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/models/generative.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/models/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/models/matching.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/models/mtl.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/models/ranking.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/serving/demo.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/serving/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/serving/vector_index.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tools/callbacks.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tools/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tools/tracking.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tools/visualization.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tutorials/ctr.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tutorials/intro.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tutorials/pipeline.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/docs/zh/tutorials/retrieval.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/data/amazon-books/README.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/data/amazon-books/preprocess_amazon_books.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/data/amazon-books/preprocess_amazon_books_hllm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/data/ml-1m/README +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/data/ml-1m/preprocess_hllm_data.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/data/ml-1m/preprocess_ml_hstu.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/run_hllm_amazon_books.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/run_hllm_movielens.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/generative/run_hstu_movielens.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/README.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/data/million-song-dataset/process_msd.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/data/ml-1m/preprocess_ml.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/data/session_based/preprocess_session_based.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/data/yidian_news/preprocess.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/movielens_utils.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_comirec.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_dssm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_facebook_dssm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_gru4rec.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_mind.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_sine.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_youtube_dnn.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_ml_youtube_sbc.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/matching/run_sbr.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/README.md +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/data/ali-ccp/preprocess_ali_ccp.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/data/amazon-beauty/preprocess_amazon_beauty.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/data/amazon-books/preprocess_amazon_books.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/data/amazon-electronics/preprocess_amazon_electronics.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/data/avazu/download_avazu.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/data/census-income/preprocess_census.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_ali_ccp_ctr_ranking.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_ali_ccp_multi_task.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_aliexpress.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_amazon_electronics.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_avazu.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_census.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_criteo.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_gradnorm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/examples/ranking/run_metabalance.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/package-lock.json +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/package.json +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tests/test_e2e_matching.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tests/test_e2e_multitask.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tests/test_e2e_ranking.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tests/test_onnx_export.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tests/test_pa_array_to_tensor.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tests/test_regularization.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/activation.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/callback.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/features.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/initializers.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/metaoptimizer.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/metric.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/basic/tracking.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/data/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/data/convert.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/generative/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/generative/hllm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/comirec.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/dssm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/dssm_facebook.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/dssm_senet.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/gru4rec.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/mind.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/narm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/sasrec.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/sine.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/stamp.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/youtube_dnn.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/matching/youtube_sbc.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/multi_task/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/multi_task/aitm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/multi_task/esmm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/multi_task/mmoe.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/multi_task/ple.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/multi_task/shared_bottom.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/afm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/autoint.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/bst.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/dcn.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/dcn_v2.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/deepffm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/deepfm.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/dien.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/din.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/edcn.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/fibinet.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/models/ranking/widedeep.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/trainers/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/__init__.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/match.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/torch_rechub/utils/mtl.py +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tutorials/DIN.ipynb +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tutorials/DeepFM.ipynb +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tutorials/Matching.ipynb +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tutorials/Milvus.ipynb +0 -0
- {torch_rechub-0.0.6 → torch_rechub-0.1.0}/tutorials/Multi_Task.ipynb +0 -0
|
@@ -56,7 +56,7 @@ jobs:
|
|
|
56
56
|
python-version: ${{ env.PYTHON_VERSION }}
|
|
57
57
|
|
|
58
58
|
- name: Cache pip packages
|
|
59
|
-
uses: actions/cache@
|
|
59
|
+
uses: actions/cache@v5
|
|
60
60
|
with:
|
|
61
61
|
path: ~/.cache/pip
|
|
62
62
|
key: ${{ runner.os }}-pip-lint-${{ hashFiles('pyproject.toml') }}
|
|
@@ -104,6 +104,9 @@ jobs:
|
|
|
104
104
|
matrix:
|
|
105
105
|
os: [ubuntu-latest, windows-latest, macos-latest]
|
|
106
106
|
|
|
107
|
+
env:
|
|
108
|
+
SKIP_MILVUS_TESTS: ${{ matrix.os != 'ubuntu-latest' && '1' || '0' }}
|
|
109
|
+
|
|
107
110
|
steps:
|
|
108
111
|
- name: Checkout code
|
|
109
112
|
uses: actions/checkout@v6
|
|
@@ -114,7 +117,7 @@ jobs:
|
|
|
114
117
|
python-version: '3.9'
|
|
115
118
|
|
|
116
119
|
- name: Cache pip packages
|
|
117
|
-
uses: actions/cache@
|
|
120
|
+
uses: actions/cache@v5
|
|
118
121
|
with:
|
|
119
122
|
path: |
|
|
120
123
|
~/.cache/pip
|
|
@@ -136,12 +139,36 @@ jobs:
|
|
|
136
139
|
# Install CPU-only PyTorch for faster CI
|
|
137
140
|
pip install torch --index-url ${{ env.TORCH_INDEX_URL }}
|
|
138
141
|
# Install the package with dev and onnx dependencies
|
|
139
|
-
pip install -e ".[dev,onnx]" || pip install -r requirements-dev.txt && pip install -e .
|
|
142
|
+
pip install -e ".[dev,annoy,faiss,milvus,onnx]" || pip install -r requirements-dev.txt && pip install -e .
|
|
143
|
+
|
|
144
|
+
- name: Start Milvus
|
|
145
|
+
if: matrix.os == 'ubuntu-latest'
|
|
146
|
+
run: |
|
|
147
|
+
# Download the installation script
|
|
148
|
+
curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.sh
|
|
149
|
+
|
|
150
|
+
# Start the Docker container
|
|
151
|
+
bash standalone_embed.sh start
|
|
152
|
+
|
|
153
|
+
- name: Wait for Milvus
|
|
154
|
+
if: matrix.os == 'ubuntu-latest'
|
|
155
|
+
run: |
|
|
156
|
+
for i in {1..60}; do
|
|
157
|
+
if curl -fsS http://localhost:9091/healthz >/dev/null; then exit 0; fi
|
|
158
|
+
sleep 2
|
|
159
|
+
done
|
|
160
|
+
exit 1
|
|
140
161
|
|
|
141
162
|
- name: Run tests
|
|
163
|
+
if: matrix.os != 'macos-latest'
|
|
142
164
|
run: |
|
|
143
165
|
pytest -c config/pytest.ini tests/ -v
|
|
144
166
|
|
|
167
|
+
- name: Run tests (skip indexing tests)
|
|
168
|
+
if: matrix.os == 'macos-latest'
|
|
169
|
+
run: |
|
|
170
|
+
pytest -c config/pytest.ini tests/ -v --ignore=tests/test_serving.py
|
|
171
|
+
|
|
145
172
|
- name: Run tests with coverage (Ubuntu only)
|
|
146
173
|
if: matrix.os == 'ubuntu-latest'
|
|
147
174
|
run: |
|
|
@@ -155,6 +182,16 @@ jobs:
|
|
|
155
182
|
flags: unittests
|
|
156
183
|
name: codecov-umbrella
|
|
157
184
|
|
|
185
|
+
- name: Shutdown Milvus
|
|
186
|
+
if: always() && matrix.os == 'ubuntu-latest'
|
|
187
|
+
run: |
|
|
188
|
+
# Stop Milvus
|
|
189
|
+
bash standalone_embed.sh stop
|
|
190
|
+
|
|
191
|
+
# Delete Milvus data
|
|
192
|
+
bash standalone_embed.sh delete
|
|
193
|
+
|
|
194
|
+
|
|
158
195
|
# ===================================================================
|
|
159
196
|
# 依赖兼容性验证 (Python 3.10+) - 仅验证依赖安装成功
|
|
160
197
|
# (仅在 push/PR 时运行,release 时跳过)
|
|
@@ -221,7 +258,7 @@ jobs:
|
|
|
221
258
|
bandit -r torch_rechub/ -s B101,B311,B614 -x tests,docs,examples -f txt
|
|
222
259
|
|
|
223
260
|
- name: Upload security scan results
|
|
224
|
-
uses: actions/upload-artifact@
|
|
261
|
+
uses: actions/upload-artifact@v6
|
|
225
262
|
if: always()
|
|
226
263
|
with:
|
|
227
264
|
name: bandit-security-report
|
|
@@ -259,7 +296,7 @@ jobs:
|
|
|
259
296
|
twine check dist/*
|
|
260
297
|
|
|
261
298
|
- name: Upload build artifacts
|
|
262
|
-
uses: actions/upload-artifact@
|
|
299
|
+
uses: actions/upload-artifact@v6
|
|
263
300
|
with:
|
|
264
301
|
name: dist-packages
|
|
265
302
|
path: dist/
|
|
@@ -7,6 +7,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|
|
7
7
|
|
|
8
8
|
---
|
|
9
9
|
|
|
10
|
+
## [0.1.0] - 2025-12-17
|
|
11
|
+
|
|
12
|
+
<!-- Release notes generated using configuration in .github/release.yml at main -->
|
|
13
|
+
|
|
14
|
+
## What's Changed
|
|
15
|
+
### ✨ 新特性 / Features
|
|
16
|
+
* Update docs and tutorials && Add ONNX quantization utilities and enhance export by @1985312383 in https://github.com/datawhalechina/torch-rechub/pull/150
|
|
17
|
+
* REFACTOR+FEATURE: Standardize retrieval backends (ANNOY/FAISS/Milvus) by @ywuenthought in https://github.com/datawhalechina/torch-rechub/pull/151
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
**Full Changelog**: https://github.com/datawhalechina/torch-rechub/compare/v0.0.6...v0.1.0
|
|
21
|
+
|
|
22
|
+
---
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
|
|
10
26
|
## [0.0.6] - 2025-12-11
|
|
11
27
|
|
|
12
28
|
<!-- Release notes generated using configuration in .github/release.yml at main -->
|
|
@@ -150,18 +150,25 @@ def test_deepfm_forward():
|
|
|
150
150
|
```python
|
|
151
151
|
def train_model(model, data_loader, optimizer):
|
|
152
152
|
"""Train a recommendation model.
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
153
|
+
|
|
154
|
+
Parameters
|
|
155
|
+
----------
|
|
156
|
+
model : torch.nn.Module
|
|
157
|
+
Model to train.
|
|
158
|
+
data_loader : DataLoader
|
|
159
|
+
Training data loader.
|
|
160
|
+
optimizer : torch.optim.Optimizer
|
|
161
|
+
Optimizer for training.
|
|
162
|
+
|
|
163
|
+
Returns
|
|
164
|
+
-------
|
|
165
|
+
float
|
|
166
|
+
Training loss.
|
|
167
|
+
|
|
168
|
+
Examples
|
|
169
|
+
--------
|
|
170
|
+
>>> model = DeepFM(features, mlp_params)
|
|
171
|
+
>>> loss = train_model(model, train_loader, optimizer)
|
|
165
172
|
"""
|
|
166
173
|
# Implementation here
|
|
167
174
|
```
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: torch-rechub
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.1.0
|
|
4
4
|
Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
|
|
5
5
|
Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
|
|
6
6
|
Project-URL: Documentation, https://www.torch-rechub.com
|
|
@@ -28,6 +28,8 @@ Requires-Dist: scikit-learn>=0.24.0
|
|
|
28
28
|
Requires-Dist: torch>=1.10.0
|
|
29
29
|
Requires-Dist: tqdm>=4.60.0
|
|
30
30
|
Requires-Dist: transformers>=4.46.3
|
|
31
|
+
Provides-Extra: annoy
|
|
32
|
+
Requires-Dist: annoy>=1.17.2; extra == 'annoy'
|
|
31
33
|
Provides-Extra: bigdata
|
|
32
34
|
Requires-Dist: pyarrow~=21.0; extra == 'bigdata'
|
|
33
35
|
Provides-Extra: dev
|
|
@@ -41,8 +43,13 @@ Requires-Dist: pytest-cov>=2.0; extra == 'dev'
|
|
|
41
43
|
Requires-Dist: pytest>=6.0; extra == 'dev'
|
|
42
44
|
Requires-Dist: toml>=0.10.2; extra == 'dev'
|
|
43
45
|
Requires-Dist: yapf==0.43.0; extra == 'dev'
|
|
46
|
+
Provides-Extra: faiss
|
|
47
|
+
Requires-Dist: faiss-cpu==1.13.0; extra == 'faiss'
|
|
48
|
+
Provides-Extra: milvus
|
|
49
|
+
Requires-Dist: pymilvus>=2.6.5; extra == 'milvus'
|
|
44
50
|
Provides-Extra: onnx
|
|
45
51
|
Requires-Dist: onnx>=1.14.0; extra == 'onnx'
|
|
52
|
+
Requires-Dist: onnxconverter-common>=1.14.0; extra == 'onnx'
|
|
46
53
|
Requires-Dist: onnxruntime>=1.14.0; extra == 'onnx'
|
|
47
54
|
Provides-Extra: tracking
|
|
48
55
|
Requires-Dist: swanlab>=0.1.0; extra == 'tracking'
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
---
|
|
2
|
+
title: Data Pipeline
|
|
3
|
+
description: Torch-RecHub data loading and preprocessing
|
|
4
|
+
---
|
|
5
|
+
|
|
6
|
+
# Data Pipeline
|
|
7
|
+
|
|
8
|
+
Torch-RecHub offers datasets, generators, and utilities for recommendation data.
|
|
9
|
+
|
|
10
|
+
## Datasets
|
|
11
|
+
|
|
12
|
+
### TorchDataset
|
|
13
|
+
Training/validation dataset with features and labels.
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
from torch_rechub.utils.data import TorchDataset
|
|
17
|
+
dataset = TorchDataset(x, y)
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
### PredictDataset
|
|
21
|
+
Prediction-only dataset (features only).
|
|
22
|
+
|
|
23
|
+
```python
|
|
24
|
+
from torch_rechub.utils.data import PredictDataset
|
|
25
|
+
dataset = PredictDataset(x)
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Data Generators
|
|
29
|
+
|
|
30
|
+
### DataGenerator
|
|
31
|
+
Build dataloaders for ranking / multi-task models.
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
from torch_rechub.utils.data import DataGenerator
|
|
35
|
+
|
|
36
|
+
dg = DataGenerator(x, y)
|
|
37
|
+
train_dl, val_dl, test_dl = dg.generate_dataloader(
|
|
38
|
+
split_ratio=[0.7, 0.1],
|
|
39
|
+
batch_size=256,
|
|
40
|
+
num_workers=8,
|
|
41
|
+
)
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
### MatchDataGenerator
|
|
45
|
+
Build dataloaders for matching/retrieval models.
|
|
46
|
+
|
|
47
|
+
```python
|
|
48
|
+
from torch_rechub.utils.data import MatchDataGenerator
|
|
49
|
+
|
|
50
|
+
dg = MatchDataGenerator(x, y)
|
|
51
|
+
train_dl, test_dl, item_dl = dg.generate_dataloader(
|
|
52
|
+
x_test_user=x_test_user,
|
|
53
|
+
x_all_item=x_all_item,
|
|
54
|
+
batch_size=256,
|
|
55
|
+
num_workers=8,
|
|
56
|
+
)
|
|
57
|
+
```
|
|
58
|
+
|
|
59
|
+
## Utilities
|
|
60
|
+
|
|
61
|
+
### get_auto_embedding_dim
|
|
62
|
+
Compute the embedding dimension from the vocabulary size: `int(floor(6 * num_classes**0.25))`, where `num_classes` is the vocabulary size.
|
|
63
|
+
|
|
64
|
+
```python
|
|
65
|
+
from torch_rechub.utils.data import get_auto_embedding_dim
|
|
66
|
+
embed_dim = get_auto_embedding_dim(1000)  # floor(6 * 1000**0.25) = 33
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
### get_loss_func
|
|
70
|
+
Return default loss by task type: BCELoss for classification, MSELoss for regression.
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
from torch_rechub.utils.data import get_loss_func
|
|
74
|
+
loss_fn = get_loss_func(task_type="classification")
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
## Typical Flow
|
|
78
|
+
|
|
79
|
+
1. Define features (Dense/Sparse/Sequence).
|
|
80
|
+
2. Load raw data.
|
|
81
|
+
3. Encode categorical features (e.g., LabelEncoder).
|
|
82
|
+
4. Process sequences (pad/truncate).
|
|
83
|
+
5. Construct samples (e.g., negative sampling).
|
|
84
|
+
6. Use DataGenerator / MatchDataGenerator to build dataloaders.
|
|
85
|
+
7. Train models with the trainers.
|
|
86
|
+
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
---
|
|
2
|
+
title: Training & Evaluation
|
|
3
|
+
description: Torch-RecHub training and evaluation
|
|
4
|
+
---
|
|
5
|
+
|
|
6
|
+
# Training & Evaluation
|
|
7
|
+
|
|
8
|
+
Torch-RecHub provides trainers for ranking, matching, multi-task, and generative models. All trainers expose a unified interface for training, evaluation, prediction, ONNX export, and optional experiment tracking/visualization.
|
|
9
|
+
|
|
10
|
+
## Experiment Tracking & Visualization
|
|
11
|
+
|
|
12
|
+
- Supports **WandB / SwanLab / TensorBoardX** as `model_logger`; you can pass a single instance or a list.
|
|
13
|
+
- Auto-logs train/validation metrics and hyperparameters: `train/loss`, `learning_rate`, `val/auc` (CTR/Match), `val/task_i_score` (MTL), `val/accuracy` (Seq).
|
|
14
|
+
- Set `model_logger=None` (default) for zero overhead when tracking is not needed.
|
|
15
|
+
|
|
16
|
+
```python
|
|
17
|
+
from torch_rechub.basic.tracking import WandbLogger, TensorBoardXLogger
|
|
18
|
+
from torch_rechub.trainers import CTRTrainer
|
|
19
|
+
|
|
20
|
+
wb = WandbLogger(project="rechub-demo", name="deepfm")
|
|
21
|
+
tb = TensorBoardXLogger(log_dir="./runs/deepfm")
|
|
22
|
+
|
|
23
|
+
trainer = CTRTrainer(model, model_logger=[wb, tb])
|
|
24
|
+
trainer.fit(train_dataloader, val_dataloader)
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Trainers
|
|
28
|
+
|
|
29
|
+
### CTRTrainer
|
|
30
|
+
|
|
31
|
+
Used for ranking (CTR prediction) models such as DeepFM, Wide&Deep, DCN.
|
|
32
|
+
|
|
33
|
+
```python
|
|
34
|
+
from torch_rechub.trainers import CTRTrainer
|
|
35
|
+
from torch_rechub.models.ranking import DeepFM
|
|
36
|
+
|
|
37
|
+
model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2})
|
|
38
|
+
|
|
39
|
+
trainer = CTRTrainer(
|
|
40
|
+
model=model,
|
|
41
|
+
optimizer_params={"lr": 0.001, "weight_decay": 0.0001},
|
|
42
|
+
n_epoch=50,
|
|
43
|
+
earlystop_patience=10,
|
|
44
|
+
device="cuda:0",
|
|
45
|
+
model_path="saved/deepfm"
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
trainer.fit(train_dataloader, val_dataloader)
|
|
49
|
+
auc = trainer.evaluate(trainer.model, test_dataloader)
|
|
50
|
+
trainer.export_onnx("deepfm.onnx")
|
|
51
|
+
trainer.visualization(save_path="deepfm_architecture.pdf")
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
**Parameters**
|
|
55
|
+
- `model`: Ranking model instance.
|
|
56
|
+
- `optimizer_fn`: Optimizer function, default `torch.optim.Adam`.
|
|
57
|
+
- `optimizer_params`: Optimizer parameters.
|
|
58
|
+
- `regularization_params`: Regularization parameters.
|
|
59
|
+
- `scheduler_fn`: Learning rate scheduler.
|
|
60
|
+
- `scheduler_params`: Scheduler parameters.
|
|
61
|
+
- `n_epoch`: Number of training epochs.
|
|
62
|
+
- `earlystop_patience`: Patience for early stopping.
|
|
63
|
+
- `device`: Training device.
|
|
64
|
+
- `gpus`: List of GPU ids.
|
|
65
|
+
- `loss_mode`: Boolean. `True` when the model returns only predictions; `False` when the model returns predictions plus auxiliary loss.
|
|
66
|
+
- `model_path`: Path to save the model.
|
|
67
|
+
|
|
68
|
+
### MatchTrainer
|
|
69
|
+
|
|
70
|
+
Used for matching/retrieval models such as DSSM, YoutubeDNN, MIND.
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
from torch_rechub.trainers import MatchTrainer
|
|
74
|
+
from torch_rechub.models.matching import DSSM
|
|
75
|
+
|
|
76
|
+
model = DSSM(
|
|
77
|
+
user_features=user_features,
|
|
78
|
+
item_features=item_features,
|
|
79
|
+
temperature=0.02,
|
|
80
|
+
user_params={"dims": [256, 128, 64]},
|
|
81
|
+
item_params={"dims": [256, 128, 64]}
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
trainer = MatchTrainer(
|
|
85
|
+
model=model,
|
|
86
|
+
mode=0, # 0: point-wise, 1: pair-wise, 2: list-wise
|
|
87
|
+
optimizer_params={"lr": 0.001},
|
|
88
|
+
n_epoch=50,
|
|
89
|
+
device="cuda:0",
|
|
90
|
+
model_path="saved/dssm"
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
trainer.fit(train_dataloader)
|
|
94
|
+
trainer.export_onnx("user_tower.onnx", mode="user")
|
|
95
|
+
trainer.export_onnx("item_tower.onnx", mode="item")
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
**Parameters**
|
|
99
|
+
- `model`: Matching model instance.
|
|
100
|
+
- `mode`: Training mode, one of 0 (point-wise), 1 (pair-wise), 2 (list-wise).
|
|
101
|
+
- `optimizer_fn`: Optimizer function, default `torch.optim.Adam`.
|
|
102
|
+
- `optimizer_params`: Optimizer parameters.
|
|
103
|
+
- `regularization_params`: Regularization parameters.
|
|
104
|
+
- `scheduler_fn`: Learning rate scheduler.
|
|
105
|
+
- `scheduler_params`: Scheduler parameters.
|
|
106
|
+
- `n_epoch`: Number of training epochs.
|
|
107
|
+
- `earlystop_patience`: Patience for early stopping.
|
|
108
|
+
- `device`: Training device.
|
|
109
|
+
- `gpus`: List of GPU ids.
|
|
110
|
+
- `model_path`: Path to save the model.
|
|
111
|
+
|
|
112
|
+
### MTLTrainer
|
|
113
|
+
|
|
114
|
+
Used for multi-task models such as MMoE, PLE, ESMM, SharedBottom.
|
|
115
|
+
|
|
116
|
+
```python
|
|
117
|
+
from torch_rechub.trainers import MTLTrainer
|
|
118
|
+
from torch_rechub.models.multi_task import MMOE
|
|
119
|
+
|
|
120
|
+
model = MMOE(
|
|
121
|
+
features=features,
|
|
122
|
+
task_types=["classification", "classification"],
|
|
123
|
+
n_expert=8,
|
|
124
|
+
expert_params={"dims": [32,16]},
|
|
125
|
+
tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}]
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
trainer = MTLTrainer(
|
|
129
|
+
model=model,
|
|
130
|
+
task_types=["classification", "classification"],
|
|
131
|
+
optimizer_params={"lr": 0.001},
|
|
132
|
+
adaptive_params={"method": "uwl"},
|
|
133
|
+
n_epoch=50,
|
|
134
|
+
earlystop_taskid=0,
|
|
135
|
+
device="cuda:0",
|
|
136
|
+
model_path="saved/mmoe"
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
trainer.fit(train_dataloader, val_dataloader)
|
|
140
|
+
trainer.export_onnx("mmoe.onnx")
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
**Parameters**
|
|
144
|
+
- `model`: Multi-task model instance.
|
|
145
|
+
- `task_types`: List of task types (`classification`, `regression`).
|
|
146
|
+
- `optimizer_fn`: Optimizer function, default `torch.optim.Adam`.
|
|
147
|
+
- `optimizer_params`: Optimizer parameters.
|
|
148
|
+
- `regularization_params`: Regularization parameters.
|
|
149
|
+
- `scheduler_fn`: Learning rate scheduler.
|
|
150
|
+
- `scheduler_params`: Scheduler parameters.
|
|
151
|
+
- `adaptive_params`: Adaptive loss weighting parameters.
|
|
152
|
+
- `n_epoch`: Number of training epochs.
|
|
153
|
+
- `earlystop_taskid`: Task id used for early stopping.
|
|
154
|
+
- `earlystop_patience`: Patience for early stopping.
|
|
155
|
+
- `device`: Training device.
|
|
156
|
+
- `gpus`: List of GPU ids.
|
|
157
|
+
- `model_path`: Path to save the model.
|
|
158
|
+
|
|
159
|
+
## Callbacks
|
|
160
|
+
|
|
161
|
+
### EarlyStopper
|
|
162
|
+
|
|
163
|
+
Used for early stopping when validation performance no longer improves.
|
|
164
|
+
|
|
165
|
+
```python
|
|
166
|
+
from torch_rechub.basic.callback import EarlyStopper
|
|
167
|
+
|
|
168
|
+
early_stopper = EarlyStopper(patience=10)
|
|
169
|
+
|
|
170
|
+
if early_stopper.stop_training(auc, model.state_dict()):
|
|
171
|
+
print(f'validation: best auc: {early_stopper.best_auc}')
|
|
172
|
+
model.load_state_dict(early_stopper.best_weights)
|
|
173
|
+
break
|
|
174
|
+
```
|
|
175
|
+
|
|
176
|
+
**Parameters**
|
|
177
|
+
- `patience`: Number of consecutive epochs without improvement before stopping.
|
|
178
|
+
- `delta`: Minimum improvement threshold to be considered progress.
|
|
179
|
+
|
|
180
|
+
## Loss Functions
|
|
181
|
+
|
|
182
|
+
### RegularizationLoss
|
|
183
|
+
|
|
184
|
+
Supports L1 and L2 regularization.
|
|
185
|
+
|
|
186
|
+
```python
|
|
187
|
+
from torch_rechub.basic.loss_func import RegularizationLoss
|
|
188
|
+
|
|
189
|
+
reg_loss_fn = RegularizationLoss(
|
|
190
|
+
embedding_l1=0.0,
|
|
191
|
+
embedding_l2=0.0001,
|
|
192
|
+
dense_l1=0.0,
|
|
193
|
+
dense_l2=0.0001
|
|
194
|
+
)
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
### BPRLoss
|
|
198
|
+
|
|
199
|
+
Pairwise loss for matching models.
|
|
200
|
+
|
|
201
|
+
```python
|
|
202
|
+
from torch_rechub.basic.loss_func import BPRLoss
|
|
203
|
+
|
|
204
|
+
bpr_loss = BPRLoss()
|
|
205
|
+
loss = bpr_loss(pos_score, neg_score)
|
|
206
|
+
```
|
|
207
|
+
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
---
|
|
2
|
+
title: Feature Definitions
|
|
3
|
+
description: Torch-RecHub feature types
|
|
4
|
+
---
|
|
5
|
+
|
|
6
|
+
# Feature Definitions
|
|
7
|
+
|
|
8
|
+
Torch-RecHub provides three core feature classes for different data types.
|
|
9
|
+
|
|
10
|
+
## DenseFeature
|
|
11
|
+
|
|
12
|
+
Numeric features (e.g., age, income).
|
|
13
|
+
|
|
14
|
+
```python
|
|
15
|
+
from torch_rechub.basic.features import DenseFeature
|
|
16
|
+
|
|
17
|
+
dense_feature = DenseFeature(name="age", embed_dim=1)
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
Parameters: `name`, `embed_dim` (always 1).
|
|
21
|
+
|
|
22
|
+
## SparseFeature
|
|
23
|
+
|
|
24
|
+
Categorical features (e.g., city, gender).
|
|
25
|
+
|
|
26
|
+
```python
|
|
27
|
+
from torch_rechub.basic.features import SparseFeature
|
|
28
|
+
|
|
29
|
+
sparse_feature = SparseFeature(
|
|
30
|
+
name="city",
|
|
31
|
+
vocab_size=100,
|
|
32
|
+
embed_dim=16,
|
|
33
|
+
shared_with=None, # share embeddings with another feature if needed
|
|
34
|
+
)
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
Parameters: `name`, `vocab_size`, `embed_dim` (auto if None), `shared_with`, `padding_idx`, `initializer`.
|
|
38
|
+
|
|
39
|
+
## SequenceFeature
|
|
40
|
+
|
|
41
|
+
Sequence or multi-hot features (e.g., behavior history, tags).
|
|
42
|
+
|
|
43
|
+
```python
|
|
44
|
+
from torch_rechub.basic.features import SequenceFeature
|
|
45
|
+
|
|
46
|
+
sequence_feature = SequenceFeature(
|
|
47
|
+
name="user_history",
|
|
48
|
+
vocab_size=10000,
|
|
49
|
+
embed_dim=32,
|
|
50
|
+
pooling="mean", # mean, sum, concat
|
|
51
|
+
)
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
Parameters: `name`, `vocab_size`, `embed_dim` (auto if None), `pooling` (mean/sum/concat), `shared_with`, `padding_idx`, `initializer`.
|
|
55
|
+
|
|
56
|
+
## Usage Example
|
|
57
|
+
|
|
58
|
+
```python
|
|
59
|
+
from torch_rechub.basic.features import DenseFeature, SparseFeature, SequenceFeature
|
|
60
|
+
|
|
61
|
+
dense_features = [
|
|
62
|
+
DenseFeature(name="age", embed_dim=1),
|
|
63
|
+
DenseFeature(name="income", embed_dim=1),
|
|
64
|
+
]
|
|
65
|
+
|
|
66
|
+
sparse_features = [
|
|
67
|
+
SparseFeature(name="city", vocab_size=100, embed_dim=16),
|
|
68
|
+
SparseFeature(name="gender", vocab_size=3, embed_dim=8),
|
|
69
|
+
]
|
|
70
|
+
|
|
71
|
+
sequence_features = [
|
|
72
|
+
SequenceFeature(name="user_history", vocab_size=10000, embed_dim=32, pooling="mean"),
|
|
73
|
+
]
|
|
74
|
+
|
|
75
|
+
all_features = dense_features + sparse_features + sequence_features
|
|
76
|
+
```
|
|
77
|
+
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
---
|
|
2
|
+
title: Core Components Overview
|
|
3
|
+
description: Torch-RecHub core components overview
|
|
4
|
+
---
|
|
5
|
+
|
|
6
|
+
# Core Components Overview
|
|
7
|
+
|
|
8
|
+
Torch-RecHub is modular: features, data, models, training, and tools are separated for clarity and extensibility.
|
|
9
|
+
|
|
10
|
+
## Architecture
|
|
11
|
+
|
|
12
|
+
1) **Feature layer** – dense, sparse, and sequence feature definitions.
|
|
13
|
+
2) **Data layer** – loading, preprocessing, and dataloader generation.
|
|
14
|
+
3) **Model layer** – ranking, matching, multi-task, and generative models.
|
|
15
|
+
4) **Training layer** – unified trainers for fit/eval/predict/ONNX export.
|
|
16
|
+
5) **Tools layer** – ONNX export, visualization, callbacks, losses, etc.
|
|
17
|
+
|
|
18
|
+
## Component Relations
|
|
19
|
+
|
|
20
|
+
- Feature layer guides preprocessing in the data layer.
|
|
21
|
+
- Data generators feed the training layer.
|
|
22
|
+
- Models are consumed by trainers.
|
|
23
|
+
- Trainers call tools for export/visualization/tracking.
|
|
24
|
+
|
|
25
|
+
## Component Details
|
|
26
|
+
|
|
27
|
+
- **Feature processing**: `DenseFeature`, `SparseFeature`, `SequenceFeature`. See [Features](/en/core/features).
|
|
28
|
+
- **Data pipeline**: `TorchDataset`, `PredictDataset`, `DataGenerator`, `MatchDataGenerator`. See [Data](/en/core/data).
|
|
29
|
+
- **Training & evaluation**: `CTRTrainer`, `MatchTrainer`, `MTLTrainer` (and generative trainer variants). See [Training & Evaluation](/en/core/evaluation).
|
|
30
|
+
|