torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +3 -1
  3. torch_rechub/basic/callback.py +2 -2
  4. torch_rechub/basic/features.py +38 -8
  5. torch_rechub/basic/initializers.py +92 -0
  6. torch_rechub/basic/layers.py +800 -46
  7. torch_rechub/basic/loss_func.py +223 -0
  8. torch_rechub/basic/metaoptimizer.py +76 -0
  9. torch_rechub/basic/metric.py +251 -0
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -0
  14. torch_rechub/models/matching/comirec.py +193 -0
  15. torch_rechub/models/matching/dssm.py +72 -0
  16. torch_rechub/models/matching/dssm_facebook.py +77 -0
  17. torch_rechub/models/matching/dssm_senet.py +87 -0
  18. torch_rechub/models/matching/gru4rec.py +85 -0
  19. torch_rechub/models/matching/mind.py +103 -0
  20. torch_rechub/models/matching/narm.py +82 -0
  21. torch_rechub/models/matching/sasrec.py +143 -0
  22. torch_rechub/models/matching/sine.py +148 -0
  23. torch_rechub/models/matching/stamp.py +81 -0
  24. torch_rechub/models/matching/youtube_dnn.py +75 -0
  25. torch_rechub/models/matching/youtube_sbc.py +98 -0
  26. torch_rechub/models/multi_task/__init__.py +5 -2
  27. torch_rechub/models/multi_task/aitm.py +83 -0
  28. torch_rechub/models/multi_task/esmm.py +19 -8
  29. torch_rechub/models/multi_task/mmoe.py +18 -12
  30. torch_rechub/models/multi_task/ple.py +41 -29
  31. torch_rechub/models/multi_task/shared_bottom.py +3 -2
  32. torch_rechub/models/ranking/__init__.py +13 -2
  33. torch_rechub/models/ranking/afm.py +65 -0
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -0
  36. torch_rechub/models/ranking/dcn.py +38 -0
  37. torch_rechub/models/ranking/dcn_v2.py +59 -0
  38. torch_rechub/models/ranking/deepffm.py +131 -0
  39. torch_rechub/models/ranking/deepfm.py +8 -7
  40. torch_rechub/models/ranking/dien.py +191 -0
  41. torch_rechub/models/ranking/din.py +31 -19
  42. torch_rechub/models/ranking/edcn.py +101 -0
  43. torch_rechub/models/ranking/fibinet.py +42 -0
  44. torch_rechub/models/ranking/widedeep.py +6 -6
  45. torch_rechub/trainers/__init__.py +4 -2
  46. torch_rechub/trainers/ctr_trainer.py +191 -0
  47. torch_rechub/trainers/match_trainer.py +239 -0
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +137 -23
  50. torch_rechub/trainers/seq_trainer.py +293 -0
  51. torch_rechub/utils/__init__.py +0 -0
  52. torch_rechub/utils/data.py +492 -0
  53. torch_rechub/utils/hstu_utils.py +198 -0
  54. torch_rechub/utils/match.py +457 -0
  55. torch_rechub/utils/mtl.py +136 -0
  56. torch_rechub/utils/onnx_export.py +353 -0
  57. torch_rechub-0.0.4.dist-info/METADATA +391 -0
  58. torch_rechub-0.0.4.dist-info/RECORD +62 -0
  59. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
  60. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
  61. torch_rechub/basic/utils.py +0 -168
  62. torch_rechub/trainers/trainer.py +0 -111
  63. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  64. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  65. torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
@@ -0,0 +1,391 @@
1
+ Metadata-Version: 2.4
2
+ Name: torch-rechub
3
+ Version: 0.0.4
4
+ Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
5
+ Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
6
+ Project-URL: Documentation, https://www.torch-rechub.com
7
+ Project-URL: Repository, https://github.com/datawhalechina/torch-rechub.git
8
+ Project-URL: Issues, https://github.com/datawhalechina/torch-rechub/issues
9
+ Author-email: rechub team <morningsky@tju.edu.cn>
10
+ License: MIT
11
+ License-File: LICENSE
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Operating System :: OS Independent
17
+ Classifier: Programming Language :: Python :: 3
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Programming Language :: Python :: 3.12
22
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
+ Requires-Python: >=3.9
24
+ Requires-Dist: accelerate>=1.0.1
25
+ Requires-Dist: numpy>=1.19.0
26
+ Requires-Dist: pandas>=1.2.0
27
+ Requires-Dist: scikit-learn>=0.24.0
28
+ Requires-Dist: torch>=1.10.0
29
+ Requires-Dist: tqdm>=4.60.0
30
+ Requires-Dist: transformers>=4.46.3
31
+ Provides-Extra: dev
32
+ Requires-Dist: bandit>=1.7.0; extra == 'dev'
33
+ Requires-Dist: flake8>=3.8.0; extra == 'dev'
34
+ Requires-Dist: isort==5.13.2; extra == 'dev'
35
+ Requires-Dist: mypy>=0.800; extra == 'dev'
36
+ Requires-Dist: pre-commit>=2.20.0; extra == 'dev'
37
+ Requires-Dist: pytest-cov>=2.0; extra == 'dev'
38
+ Requires-Dist: pytest>=6.0; extra == 'dev'
39
+ Requires-Dist: toml>=0.10.2; extra == 'dev'
40
+ Requires-Dist: yapf==0.43.0; extra == 'dev'
41
+ Provides-Extra: onnx
42
+ Requires-Dist: onnx>=1.12.0; extra == 'onnx'
43
+ Requires-Dist: onnxruntime>=1.12.0; extra == 'onnx'
44
+ Description-Content-Type: text/markdown
45
+
46
+ # 🔥 Torch-RecHub - A Lightweight, Efficient, and Easy-to-Use PyTorch Recommendation Framework
47
+
48
+ > 🚀 **30+ mainstream recommendation models** | 🎯 **Works out of the box** | 📦 **One-click ONNX deployment** | 🤖 **Generative recommendation support (HSTU/HLLM)**
49
+
50
+ [![License](https://img.shields.io/badge/license-MIT-blue?style=for-the-badge)](LICENSE)
51
+ ![GitHub Repo stars](https://img.shields.io/github/stars/datawhalechina/torch-rechub?style=for-the-badge)
52
+ ![GitHub forks](https://img.shields.io/github/forks/datawhalechina/torch-rechub?style=for-the-badge)
53
+ ![GitHub issues](https://img.shields.io/github/issues/datawhalechina/torch-rechub?style=for-the-badge)
54
+ [![Python Version](https://img.shields.io/badge/python-3.9%2B-orange?style=for-the-badge)](https://www.python.org/)
55
+ [![PyTorch Version](https://img.shields.io/badge/pytorch-1.7%2B-orange?style=for-the-badge)](https://pytorch.org/)
56
+ [![annoy Version](https://img.shields.io/badge/annoy-1.17%2B-orange?style=for-the-badge)](https://github.com/spotify/annoy)
57
+ [![pandas Version](https://img.shields.io/badge/pandas-1.2%2B-orange?style=for-the-badge)](https://pandas.pydata.org/)
58
+ [![numpy Version](https://img.shields.io/badge/numpy-1.19%2B-orange?style=for-the-badge)](https://numpy.org/)
59
+ [![scikit-learn Version](https://img.shields.io/badge/scikit_learn-0.23%2B-orange?style=for-the-badge)](https://scikit-learn.org/)
60
+ [![torch-rechub Version](https://img.shields.io/badge/torch_rechub-0.0.3%2B-orange?style=for-the-badge)](https://pypi.org/project/torch-rechub/)
61
+
62
+ [English](README_en.md) | Simplified Chinese
63
+
64
+ **Online docs:** https://datawhalechina.github.io/torch-rechub/ (English) | https://datawhalechina.github.io/torch-rechub/zh/ (Simplified Chinese)
65
+
66
+ **Torch-RecHub**: **an industrial-grade recommender system in 10 lines of code**. 30+ mainstream models work out of the box, with one-click ONNX export, so you can focus on the business rather than the engineering.
67
+
68
+ ![Torch-RecHub banner](docs/public/img/banner.png)
69
+
70
+ ## 🎯 Why Torch-RecHub?
71
+
72
+ | Feature | Torch-RecHub | Other frameworks |
73
+ |------|-------------|---------|
74
+ | Lines of code | **10 lines** for training + evaluation + deployment | 100+ lines |
75
+ | Model coverage | **30+** mainstream models | Limited |
76
+ | Generative recommendation | ✅ HSTU/HLLM (Meta 2024) | ❌ |
77
+ | One-click ONNX export | ✅ Built in | Manual adaptation required |
78
+ | Learning curve | Very gentle | Steep |
79
+
80
+ ## ✨ Features
81
+
82
+ * **Modular design:** easy to add new models, datasets, and evaluation metrics.
83
+ * **Built on PyTorch:** leverages PyTorch's dynamic graphs and GPU acceleration.
84
+ * **Rich model zoo:** covers **30+** classic and state-of-the-art recommendation algorithms (matching, ranking, multi-task, generative recommendation, and more).
85
+ * **Standardized pipeline:** unified data loading, training, and evaluation workflow.
86
+ * **Easy configuration:** adjust experiment settings via config files or command-line arguments.
87
+ * **Reproducibility:** designed to keep experiment results reproducible.
88
+ * **ONNX export:** export trained models to ONNX for easy deployment to production.
89
+ * **Other features:** e.g. negative sampling, multi-task learning, and more.
90
+
91
+ ## 📖 Table of Contents
92
+
93
+ - [🔥 Torch-RecHub - A Lightweight, Efficient, and Easy-to-Use PyTorch Recommendation Framework](#-torch-rechub---轻量高效易用的-pytorch-推荐系统框架)
94
+ - [🎯 Why Torch-RecHub?](#-为什么选择-torch-rechub)
95
+ - [✨ Features](#-特性)
96
+ - [📖 Table of Contents](#-目录)
97
+ - [🔧 Installation](#-安装)
98
+ - [Requirements](#环境要求)
99
+ - [Installation Steps](#安装步骤)
100
+ - [🚀 Quick Start](#-快速开始)
101
+ - [📂 Project Structure](#-项目结构)
102
+ - [💡 Supported Models](#-支持的模型)
103
+ - [📊 Supported Datasets](#-支持的数据集)
104
+ - [🧪 Examples](#-示例)
105
+ - [Ranking (CTR Prediction)](#精排ctr预测)
106
+ - [Multi-Task Ranking](#多任务排序)
107
+ - [Matching Models](#召回模型)
108
+ - [👨‍💻‍ Contributors](#-贡献者)
109
+ - [🤝 Contributing](#-贡献指南)
110
+ - [📜 License](#-许可证)
111
+ - [📚 Citation](#-引用)
112
+ - [📫 Contact](#-联系方式)
113
+ - [⭐️ Star History](#️-项目-star-历史)
114
+
115
+ ## 🔧 Installation
116
+
117
+ ### Requirements
118
+
119
+ * Python 3.9+
120
+ * PyTorch 1.7+ (a CUDA-enabled build is recommended for GPU acceleration)
121
+ * NumPy
122
+ * Pandas
123
+ * SciPy
124
+ * Scikit-learn
125
+
126
+ ### Installation Steps
127
+
128
+ **Stable release (recommended for most users):**
129
+ ```bash
130
+ pip install torch-rechub
131
+ ```
132
+
133
+ **Latest version (from source):**
134
+ ```bash
135
+ # Install uv first (if not already installed)
136
+ pip install uv
137
+
138
+ # Clone and install
139
+ git clone https://github.com/datawhalechina/torch-rechub.git
140
+ cd torch-rechub
141
+ uv sync
142
+ ```
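+
+ After either install path, a quick import check confirms that the package and its PyTorch backend are visible. A minimal sketch (the `__version__` attribute is read defensively, since not every release necessarily exposes it):
+
+ ```python
+ # Minimal install check: confirms torch-rechub and PyTorch import correctly.
+ import torch
+ import torch_rechub
+
+ # __version__ is read defensively in case a given release does not define it.
+ print("torch-rechub:", getattr(torch_rechub, "__version__", "unknown"))
+ print("pytorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
+ ```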
143
+
144
+
145
+
146
+ ## 🚀 Quick Start
147
+
148
+ Here is a simple example showing how to train a model (e.g. DSSM) on the MovieLens dataset:
149
+
150
+ ```bash
151
+ # Clone the repository (if using the latest version)
152
+ git clone https://github.com/datawhalechina/torch-rechub.git
153
+ cd torch-rechub
154
+ uv sync
155
+
156
+ # Run the example
157
+ python examples/matching/run_ml_dssm.py
158
+
159
+ # Or with custom arguments:
160
+ python examples/matching/run_ml_dssm.py --model_name dssm --device 'cuda:0' --learning_rate 0.001 --epoch 50 --batch_size 4096 --weight_decay 0.0001 --save_dir 'saved/dssm_ml-100k'
161
+ ```
162
+
163
+ After training, the model files are saved under `saved/dssm_ml-100k` (or whichever directory you configure).
164
+
165
+ ## 📂 Project Structure
166
+
167
+ ```
168
+ torch-rechub/ # project root
169
+ ├── README.md # project documentation
170
+ ├── pyproject.toml # project configuration and dependencies
171
+ ├── torch_rechub/ # core library
172
+ │ ├── basic/ # basic building blocks
173
+ │ │ ├── activation.py # activation functions
174
+ │ │ ├── features.py # feature definitions
175
+ │ │ ├── layers.py # neural network layers
176
+ │ │ ├── loss_func.py # loss functions
177
+ │ │ └── metric.py # evaluation metrics
178
+ │ ├── models/ # recommendation model implementations
179
+ │ │ ├── matching/ # matching models (DSSM/MIND/GRU4Rec, etc.)
180
+ │ │ ├── ranking/ # ranking models (WideDeep/DeepFM/DIN, etc.)
181
+ │ │ └── multi_task/ # multi-task models (MMoE/ESMM, etc.)
182
+ │ ├── trainers/ # training framework
183
+ │ │ ├── ctr_trainer.py # CTR prediction trainer
184
+ │ │ ├── match_trainer.py # matching model trainer
185
+ │ │ └── mtl_trainer.py # multi-task learning trainer
186
+ │ └── utils/ # utility functions
187
+ │ ├── data.py # data processing utilities
188
+ │ ├── match.py # matching utilities
189
+ │ ├── mtl.py # multi-task utilities
190
+ │ └── onnx_export.py # ONNX export utilities
191
+ ├── examples/ # example scripts
192
+ │ ├── matching/ # matching examples
193
+ │ ├── ranking/ # ranking examples
194
+ │ └── generative/ # generative recommendation examples (HSTU, HLLM, etc.)
195
+ ├── docs/ # documentation (VitePress, multilingual)
196
+ ├── tutorials/ # Jupyter tutorials
197
+ ├── tests/ # unit tests
198
+ ├── config/ # configuration files
199
+ └── scripts/ # helper scripts
200
+ ```
201
+
202
+ ## 💡 Supported Models
203
+
204
+ The framework currently supports **30+** mainstream recommendation models:
205
+
206
+ ### Ranking Models - 13 models
207
+
208
+ | Model | Paper | Description |
209
+ |------|------|------|
210
+ | **DeepFM** | [IJCAI 2017](https://arxiv.org/abs/1703.04247) | Joint training of FM and a deep network |
211
+ | **Wide&Deep** | [DLRS 2016](https://arxiv.org/abs/1606.07792) | Combines memorization and generalization |
212
+ | **DCN** | [KDD 2017](https://arxiv.org/abs/1708.05123) | Explicit feature-cross network |
213
+ | **DCN-v2** | [WWW 2021](https://arxiv.org/abs/2008.13535) | Improved cross network |
214
+ | **DIN** | [KDD 2018](https://arxiv.org/abs/1706.06978) | Attention mechanism to capture user interest |
215
+ | **DIEN** | [AAAI 2019](https://arxiv.org/abs/1809.03672) | Interest evolution modeling |
216
+ | **BST** | [DLP-KDD 2019](https://arxiv.org/abs/1905.06874) | Transformer-based sequence modeling |
217
+ | **AFM** | [IJCAI 2017](https://arxiv.org/abs/1708.04617) | Attentional factorization machine |
218
+ | **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | Automatic feature interaction learning |
219
+ | **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | Feature importance + bilinear interaction |
220
+ | **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | Field-aware factorization machine with a deep network |
221
+ | **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | Enhanced cross network |
222
+
223
+ ### Matching Models - 12 models
224
+
225
+ | Model | Paper | Description |
226
+ |------|------|------|
227
+ | **DSSM** | [CIKM 2013](https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf) | Classic two-tower retrieval model |
228
+ | **YoutubeDNN** | [RecSys 2016](https://dl.acm.org/doi/10.1145/2959100.2959190) | YouTube deep candidate generation |
229
+ | **YoutubeSBC** | [RecSys 2019](https://dl.acm.org/doi/10.1145/3298689.3346997) | Sampling-bias-corrected variant |
230
+ | **MIND** | [CIKM 2019](https://arxiv.org/abs/1904.08030) | Multi-interest extraction with dynamic routing |
231
+ | **SINE** | [WSDM 2021](https://arxiv.org/abs/2103.06920) | Sparse interest network |
232
+ | **GRU4Rec** | [ICLR 2016](https://arxiv.org/abs/1511.06939) | GRU-based sequential recommendation |
233
+ | **SASRec** | [ICDM 2018](https://arxiv.org/abs/1808.09781) | Self-attentive sequential recommendation |
234
+ | **NARM** | [CIKM 2017](https://arxiv.org/abs/1711.04725) | Neural attentive session-based recommendation |
235
+ | **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | Short-term attention/memory priority |
236
+ | **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | Controllable multi-interest recommendation |
237
+
238
+ ### Multi-Task Models - 5 models
239
+
240
+ | Model | Paper | Description |
241
+ |------|------|------|
242
+ | **ESMM** | [SIGIR 2018](https://arxiv.org/abs/1804.07931) | Entire-space multi-task modeling |
243
+ | **MMoE** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3220007) | Multi-gate mixture-of-experts |
244
+ | **PLE** | [RecSys 2020](https://dl.acm.org/doi/10.1145/3383313.3412236) | Progressive layered extraction |
245
+ | **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | Adaptive information transfer |
246
+ | **SharedBottom** | - | Classic shared-bottom multi-task model |
247
+
248
+ ### Generative Recommendation - 2 models
249
+
250
+ | Model | Paper | Description |
251
+ |------|------|------|
252
+ | **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | Hierarchical sequential transduction units, powering Meta's trillion-parameter recommender |
253
+ | **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | Hierarchical large-language-model recommendation that leverages LLM semantic understanding |
254
+
255
+ ## 📊 Supported Datasets
256
+
257
+ The framework has built-in support or preprocessing scripts for the following common datasets:
258
+
259
+ * **MovieLens**
260
+ * **Amazon**
261
+ * **Criteo**
262
+ * **Avazu**
263
+ * **Census-Income**
264
+ * **BookCrossing**
265
+ * **Ali-ccp**
266
+ * **Yidian**
267
+ * ...
268
+
269
+ The expected data format is usually an interaction file containing the following fields:
270
+ - User ID
271
+ - Item ID
272
+ - Rating (optional)
273
+ - Timestamp (optional)
274
+
275
+ See the example code in the `tutorials` directory for the exact format requirements.
276
+
277
+ You can easily plug in your own dataset: just make sure it matches the format the framework expects, or write a custom data loader (a minimal sketch follows below).
278
+
279
+
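+ The exact preprocessing depends on the dataset, but the following minimal sketch shows one way to turn such an interaction table into the `{feature_name: numpy array}` inputs and label array that the trainers' data generators consume (see the CTR example below). The toy columns and the "rating >= 4 means positive" rule are illustrative assumptions, not part of the library.
+
+ ```python
+ # A minimal, illustrative sketch (not the library's own loader): turn an
+ # interaction table with the fields above into the dict-of-arrays features and
+ # label array that the DataGenerator in the CTR example below works with.
+ import numpy as np
+ import pandas as pd
+ from sklearn.preprocessing import LabelEncoder
+
+ # Toy interaction table standing in for a real log file.
+ df = pd.DataFrame({
+     "user_id": ["u1", "u2", "u1", "u3"],
+     "item_id": ["i9", "i7", "i7", "i2"],
+     "rating": [5, 2, 4, 1],
+     "timestamp": [1700000001, 1700000002, 1700000003, 1700000004],
+ })
+
+ # Encode raw IDs to contiguous integers so they can index embedding tables.
+ for col in ["user_id", "item_id"]:
+     df[col] = LabelEncoder().fit_transform(df[col])
+
+ # Illustrative label: treat high ratings as positive interactions.
+ y = (df["rating"] >= 4).astype(np.float32).values
+
+ # Features as a dict of numpy arrays keyed by feature name.
+ x = {"user_id": df["user_id"].values, "item_id": df["item_id"].values}
+ ```
+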
280
+ ## 🧪 Examples
281
+
282
+ Usage examples for all models can be found in `/examples`.
283
+
284
+
285
+ ### Ranking (CTR Prediction)
286
+
287
+ ```python
288
+ from torch_rechub.models.ranking import DeepFM
289
+ from torch_rechub.trainers import CTRTrainer
290
+ from torch_rechub.utils.data import DataGenerator
291
+
292
+ dg = DataGenerator(x, y)
293
+ train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)
294
+
295
+ model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
296
+
297
+ ctr_trainer = CTRTrainer(model)
298
+ ctr_trainer.fit(train_dataloader, val_dataloader)
299
+ auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
300
+ ctr_trainer.export_onnx("deepfm.onnx")
301
+ ```
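+
+ The snippet above assumes `x`, `y`, `deep_features`, and `fm_features` already exist. The sketch below shows one way they might be declared, assuming the `DenseFeature`/`SparseFeature` helpers in `torch_rechub.basic.features` keep their usual `(name)` / `(name, vocab_size, embed_dim)` style of constructor; the feature names, vocabulary sizes, and random toy data are purely illustrative.
+
+ ```python
+ # Illustrative feature declarations and toy inputs for the DeepFM example above.
+ import numpy as np
+ from torch_rechub.basic.features import DenseFeature, SparseFeature
+
+ n_users, n_items, n_samples = 1000, 5000, 10000
+
+ # Sparse (categorical) features are embedded; dense features are used as-is.
+ sparse_feats = [
+     SparseFeature("user_id", vocab_size=n_users, embed_dim=16),
+     SparseFeature("item_id", vocab_size=n_items, embed_dim=16),
+ ]
+ dense_feats = [DenseFeature("price")]
+
+ deep_features = dense_feats + sparse_feats  # fed to the MLP tower
+ fm_features = sparse_feats  # fed to the FM interaction part
+
+ # Toy arrays keyed by feature name (replace with real data).
+ x = {
+     "user_id": np.random.randint(0, n_users, n_samples),
+     "item_id": np.random.randint(0, n_items, n_samples),
+     "price": np.random.rand(n_samples).astype(np.float32),
+ }
+ y = np.random.randint(0, 2, n_samples).astype(np.float32)
+ ```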
302
+
303
+ ### Multi-Task Ranking
304
+
305
+ ```python
306
+ from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
307
+ from torch_rechub.trainers import MTLTrainer
308
+
309
+ task_types = ["classification", "classification"]
310
+ model = MMOE(features, task_types, 8, expert_params={"dims": [32,16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])
311
+
312
+ mtl_trainer = MTLTrainer(model)
313
+ mtl_trainer.fit(train_dataloader, val_dataloader)
314
+ auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)
315
+ mtl_trainer.export_onnx("mmoe.onnx")
316
+ ```
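+
+ Once a model has been exported (for example the `mmoe.onnx` or `deepfm.onnx` files produced above), it can be served without PyTorch through `onnxruntime`, which the optional extra installs (`pip install torch-rechub[onnx]`). A minimal sketch that reads the input names and shapes from the exported graph instead of assuming them:
+
+ ```python
+ # Minimal ONNX inference sketch; input names, shapes, and dtypes depend on the
+ # exported model's feature configuration, so they are read from the graph itself.
+ import numpy as np
+ import onnxruntime as ort
+
+ sess = ort.InferenceSession("mmoe.onnx", providers=["CPUExecutionProvider"])
+
+ # Build one dummy row per declared input (replace with real feature batches).
+ feeds = {}
+ for inp in sess.get_inputs():
+     shape = [d if isinstance(d, int) else 1 for d in inp.shape]  # dynamic dims -> 1
+     dtype = np.float32 if "float" in inp.type else np.int64  # crude dtype guess
+     feeds[inp.name] = np.zeros(shape, dtype=dtype)
+
+ outputs = sess.run(None, feeds)
+ print("prediction:", outputs[0])
+ ```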
317
+
318
+ ### Matching Models
319
+
320
+ ```python
321
+ from torch_rechub.models.matching import DSSM
322
+ from torch_rechub.trainers import MatchTrainer
323
+ from torch_rechub.utils.data import MatchDataGenerator
324
+
325
+ dg = MatchDataGenerator(x, y)
326
+ train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)
327
+
328
+ model = DSSM(user_features, item_features, temperature=0.02,
329
+ user_params={
330
+ "dims": [256, 128, 64],
331
+ "activation": 'prelu',
332
+ },
333
+ item_params={
334
+ "dims": [256, 128, 64],
335
+ "activation": 'prelu',
336
+ })
337
+
338
+ match_trainer = MatchTrainer(model)
339
+ match_trainer.fit(train_dl)
340
+ match_trainer.export_onnx("dssm.onnx")
341
+ # For two-tower models, the user tower and item tower can be exported separately:
342
+ # match_trainer.export_onnx("user_tower.onnx", mode="user")
343
+ # match_trainer.export_onnx("dssm_item.onnx", tower="item")
344
+ ```
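+
+ For serving a two-tower model, the usual pattern is to pre-compute item embeddings with the item tower and answer user queries through an approximate-nearest-neighbour index. The sketch below uses `annoy` (badged at the top of this README) and assumes `user_embeddings` / `item_embeddings` are numpy arrays already produced by the trained towers; how you obtain them depends on the model's inference API, so placeholders are used here.
+
+ ```python
+ # Illustrative ANN retrieval with annoy; the embeddings are placeholders standing
+ # in for arrays of shape (n_items, dim) and (n_users, dim) from the two towers.
+ import numpy as np
+ from annoy import AnnoyIndex
+
+ dim = 64  # must match the final tower dimension ([256, 128, 64] above ends in 64)
+ item_embeddings = np.random.rand(10000, dim).astype(np.float32)  # placeholder
+ user_embeddings = np.random.rand(100, dim).astype(np.float32)    # placeholder
+
+ # Build the item index once, offline.
+ index = AnnoyIndex(dim, "angular")  # angular distance ~ cosine similarity
+ for item_idx, vec in enumerate(item_embeddings):
+     index.add_item(item_idx, vec)
+ index.build(10)  # 10 trees; more trees -> better recall, slower build
+
+ # Online: top-10 candidate items for one user embedding.
+ top_items = index.get_nns_by_vector(user_embeddings[0], 10)
+ print(top_items)
+ ```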
345
+
346
+ ## 👨‍💻‍ Contributors
347
+
348
+ Thanks to all our contributors!
349
+
350
+ ![GitHub contributors](https://img.shields.io/github/contributors/datawhalechina/torch-rechub?color=32A9C3&labelColor=1B3C4A&logo=contributorcovenant)
351
+
352
+ [![contributors](https://contrib.rocks/image?repo=datawhalechina/torch-rechub)](https://github.com/datawhalechina/torch-rechub/graphs/contributors)
353
+
354
+ ## 🤝 Contributing
355
+
356
+ We welcome contributions of all kinds! See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed guidelines.
357
+
358
+ Bug reports and feature requests via [Issues](https://github.com/datawhalechina/torch-rechub/issues) are also welcome.
359
+
360
+ ## 📜 License
361
+
362
+ This project is released under the [MIT License](LICENSE).
363
+
364
+ ## 📚 Citation
365
+
366
+ If you use this framework in your research or work, please consider citing:
367
+
368
+ ```bibtex
369
+ @misc{torch_rechub,
370
+ title = {Torch-RecHub},
371
+ author = {Datawhale},
372
+ year = {2022},
373
+ publisher = {GitHub},
374
+ journal = {GitHub repository},
375
+ howpublished = {\url{https://github.com/datawhalechina/torch-rechub}},
376
+ note = {A PyTorch-based recommender system framework providing easy-to-use and extensible solutions}
377
+ }
378
+ ```
379
+
380
+ ## 📫 Contact
381
+
382
+ * **Project lead:** [1985312383](https://github.com/1985312383)
383
+ * [**GitHub Discussions**](https://github.com/datawhalechina/torch-rechub/discussions)
384
+
385
+ ## ⭐️ Star History
386
+
387
+ [![Star History Chart](https://api.star-history.com/svg?repos=datawhalechina/torch-rechub&type=Date)](https://www.star-history.com/#datawhalechina/torch-rechub&Date)
388
+
389
+ ---
390
+
391
+ *Last updated: 2025-12-04*
@@ -0,0 +1,62 @@
1
+ torch_rechub/__init__.py,sha256=XUwV85oz-uIokuE9qj3nmbUQg3EY8dZcDMohlob3suw,245
2
+ torch_rechub/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ torch_rechub/basic/activation.py,sha256=hIZDCe7cAgV3bX2UnvUrkO8pQs4iXxkQGD0J4GejbVg,1600
4
+ torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G0,946
5
+ torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
6
+ torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
7
+ torch_rechub/basic/layers.py,sha256=URWk78dlffMOAhDVDhOhugcr4nmwEa192AI1diktC-4,39653
8
+ torch_rechub/basic/loss_func.py,sha256=6bjljqpiuUP6O8-wUbGd8FSvflY5Dp_DV_57OuQVMz4,7969
9
+ torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
10
+ torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
11
+ torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ torch_rechub/models/generative/__init__.py,sha256=TsCdVIhOcalQwqKZKjEuNbHKyIjyclapKGNwYfFR7TM,135
13
+ torch_rechub/models/generative/hllm.py,sha256=6Vrp5Bh0fTFHCn7C-3EqzOyc7UunOyEY9TzAKGHrW-8,9669
14
+ torch_rechub/models/generative/hstu.py,sha256=M1ByAWHxrkvmwaXPNdGFrbAQYyYJswnABTY3jCEceyg,7846
15
+ torch_rechub/models/matching/__init__.py,sha256=fjWOzJB8loPGy8rJMG-6G-NUISp7k3sD_1FdsKGw1as,471
16
+ torch_rechub/models/matching/comirec.py,sha256=8KB5rg2FWlZaG73CBI7_J8-J-XpjTPnblwh5OsbtAbc,9439
17
+ torch_rechub/models/matching/dssm.py,sha256=1Q0JYpt1h_7NWlLN5a_RbCoUSubZwpYTVEXccSn44eg,3003
18
+ torch_rechub/models/matching/dssm_facebook.py,sha256=n3MS7FT_kyJSDnVTlPCv_nPJ0MHCtMgJRUPDRh7jBLM,3508
19
+ torch_rechub/models/matching/dssm_senet.py,sha256=_E-xEh44XvOaBHP8XdSRkFsTvajhovxlYyCt3H9P61c,4052
20
+ torch_rechub/models/matching/gru4rec.py,sha256=cJtYCkFyg3cPYkOy_YeXRAsTev0cBPiicrj68xJup9k,3932
21
+ torch_rechub/models/matching/mind.py,sha256=NIUeqWhrnZeiFDMNFvXfMx1GMBMaCZnc6nxNZCJpwSE,4933
22
+ torch_rechub/models/matching/narm.py,sha256=2dlTuan9AFrku53WJlBbTwgLlfOHsas3-JBFGxEz7oE,3167
23
+ torch_rechub/models/matching/sasrec.py,sha256=QDfKrFl-aduWg6rY3R13RrdpMiApVugDmtEsWJulgzg,5534
24
+ torch_rechub/models/matching/sine.py,sha256=sUTUHbnewdSBd51epDIp9j-B1guKkhm6eM-KkZ3oS3Q,6746
25
+ torch_rechub/models/matching/stamp.py,sha256=DBVM3iCoQTBKwO7oKHg5SCCDXqTuRJ4Ko1n7StgEovA,3308
26
+ torch_rechub/models/matching/youtube_dnn.py,sha256=EQV_GoEs2Hxwg1U3Dj7-lWkEejEqGmtZ7D9CgfknQdA,3368
27
+ torch_rechub/models/matching/youtube_sbc.py,sha256=paw9uRnbNw_-EaFpRogy7rB4vhw4KN0Qf8BfQylTj4I,4757
28
+ torch_rechub/models/multi_task/__init__.py,sha256=5N8aJ32fzxniDm4d-AeNSi81CFWyBhjoSaK3OC-XCkY,189
29
+ torch_rechub/models/multi_task/aitm.py,sha256=hlG4opauSmM4SNZBPqogPZKNCPheGCVx--JSEWXeIJ8,3355
30
+ torch_rechub/models/multi_task/esmm.py,sha256=y6Gv_mWRw7lcZm-wjw2OHJVwvDHFfvkVSnJvQGO6kUk,2742
31
+ torch_rechub/models/multi_task/mmoe.py,sha256=yprs9P5vL6C4mf3lHb0uyjyJisYVny5UZ6Q9MGVfn-0,2831
32
+ torch_rechub/models/multi_task/ple.py,sha256=5QieTfL-jhHMzuSOlmgXc0dGARJzpLT2sHHdMedjtpM,6448
33
+ torch_rechub/models/multi_task/shared_bottom.py,sha256=XH32YOGsb8M97ggcm6RPY2k3T5g9GgeZsJ2NjwpZoko,1866
34
+ torch_rechub/models/ranking/__init__.py,sha256=VWhbC4918N1hDiUIj7shHp-N7oWOYYpLc4siBtMUV0w,447
35
+ torch_rechub/models/ranking/afm.py,sha256=VbTa3HSyYmpTMdSOr54ex6HOn_uA9tYtmlwX7IKDhUs,2267
36
+ torch_rechub/models/ranking/autoint.py,sha256=VQ8BrPYaSXZzDcSvQ0spcxl5rSqNL-gqemH7UKiLUDk,3880
37
+ torch_rechub/models/ranking/bst.py,sha256=b9cZUnFV51Fw4UAOuPJpog9pHMb2cuGr-yNAvqK7Xpg,3632
38
+ torch_rechub/models/ranking/dcn.py,sha256=ErWAF9MlKsZj_aF1DQQa1cHkobMyo7J5CTreUptRzhw,1302
39
+ torch_rechub/models/ranking/dcn_v2.py,sha256=09mx4LP8D1CVCpFZpJmjCWUm6OzJ9lXgyAvSyaaxzQo,2832
40
+ torch_rechub/models/ranking/deepffm.py,sha256=XYYnpwsJ028fyNcH1MWGPWQVrlonO5aiZO7bpC25PnE,6044
41
+ torch_rechub/models/ranking/deepfm.py,sha256=5yKLrdLPuD3NigGL3bKnG5HS3kqCuRw7gAIkXl6qY9Q,1783
42
+ torch_rechub/models/ranking/dien.py,sha256=2jaPluJf0K_ctHM3MTJ042bXxGLI8KphOTkKVScKUAg,8817
43
+ torch_rechub/models/ranking/din.py,sha256=HsOCEErea3KwEiyWw4M_aX_LMC_-Sqs1C_zeRLKLV_c,4542
44
+ torch_rechub/models/ranking/edcn.py,sha256=6f_S8I6Ir16kCIU54R4EfumWfUFOND5KDKUPHMgsVU0,4997
45
+ torch_rechub/models/ranking/fibinet.py,sha256=fmEJ9WkO8Mn0RtK_8aRHlnQFh_jMBPO0zODoHZPWmDA,2234
46
+ torch_rechub/models/ranking/widedeep.py,sha256=eciRvWRBHLlctabLLS5NB7k3MnqrWXCBdpflOU6jMB0,1636
47
+ torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
48
+ torch_rechub/trainers/ctr_trainer.py,sha256=RDUXkn7GwLzs3f0kWZwGDNCpqiMeGXo7R6ezFeZdPg8,9075
49
+ torch_rechub/trainers/match_trainer.py,sha256=xox5eaPKjSgErJQpbSr29sbyGs1p2sFaKEjxACE6uMI,11276
50
+ torch_rechub/trainers/matching.md,sha256=vIBQ3UMmVpUpyk38rrkelFwm_wXVXqMOuqzYZ4M8bzw,30
51
+ torch_rechub/trainers/mtl_trainer.py,sha256=tC4c2KIc-H8Wvj4qCzcW6TyfMLRPJyfQvTaN0dDePFg,12598
52
+ torch_rechub/trainers/seq_trainer.py,sha256=lXKRx7XbZ3iJuqp_f05vw_jkn8X5j8HmH6Nr-typiIU,12043
53
+ torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
54
+ torch_rechub/utils/data.py,sha256=vzLAAVt6dujg_vbGhQewiJc0l6JzwzdcM_9EjoOz898,19882
55
+ torch_rechub/utils/hstu_utils.py,sha256=qLON_pJDC-kDyQn1PoN_HaHi5xTNCwZPgJeV51Z61Lc,6207
56
+ torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
57
+ torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
58
+ torch_rechub/utils/onnx_export.py,sha256=uRcAD4uZ3eIQbM-DPhdc0bkaPaslNsOYny6BOeLVBfU,13660
59
+ torch_rechub-0.0.4.dist-info/METADATA,sha256=SNm71v_YOfculnc13p266bD_8yLo0U_16F_aJQPDvYo,16149
60
+ torch_rechub-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
61
+ torch_rechub-0.0.4.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
62
+ torch_rechub-0.0.4.dist-info/RECORD,,
@@ -1,5 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.37.1)
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
-
@@ -1,6 +1,6 @@
1
1
  MIT License
2
2
 
3
- Copyright (c) 2022 Mincai Lai
3
+ Copyright (c) 2022 Datawhale
4
4
 
5
5
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
6
  of this software and associated documentation files (the "Software"), to deal
@@ -1,168 +0,0 @@
1
- import random
2
- import torch
3
- import numpy as np
4
- import pandas as pd
5
- from torch.utils.data import Dataset, DataLoader, random_split
6
- from sklearn.metrics import roc_auc_score, mean_squared_error
7
-
8
-
9
- class TorchDataset(Dataset):
10
-
11
- def __init__(self, x, y):
12
- super(TorchDataset, self).__init__()
13
- self.x = x
14
- self.y = y
15
-
16
- def __getitem__(self, index):
17
- return {k: v[index] for k, v in self.x.items()}, self.y[index]
18
-
19
- def __len__(self):
20
- return len(self.y)
21
-
22
-
23
- class DataGenerator(object):
24
-
25
- def __init__(self, x, y):
26
- super(DataGenerator, self).__init__()
27
- self.dataset = TorchDataset(x, y)
28
- self.length = len(self.dataset)
29
-
30
- def generate_dataloader(self, x_val=None, y_val=None, x_test=None, y_test=None, split_ratio=None, batch_size=16, num_workers=8):
31
- if split_ratio != None:
32
- train_length = int(self.length * split_ratio[0])
33
- val_length = int(self.length * split_ratio[1])
34
- test_length = self.length - train_length - val_length
35
- print("the samples of train : val : test are %d : %d : %d" % (train_length, val_length, test_length))
36
- train_dataset, val_dataset, test_dataset = random_split(self.dataset, (train_length, val_length, test_length))
37
- else:
38
- train_dataset = self.dataset
39
- val_dataset = TorchDataset(x_val, y_val)
40
- test_dataset = TorchDataset(x_test, y_test)
41
-
42
- train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
43
- val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)
44
- test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
45
- return train_dataloader, val_dataloader, test_dataloader
46
-
47
-
48
- class PredictDataset(Dataset):
49
-
50
- def __init__(self, x):
51
- super(TorchDataset, self).__init__()
52
- self.x = x
53
-
54
- def __getitem__(self, index):
55
- return {k: v[index] for k, v in self.x.items()}
56
-
57
- def __len__(self):
58
- return len(self.x[list(self.x.keys())[0]])
59
-
60
-
61
- def get_auto_embedding_dim(num_classes):
62
- """ Calculate the dim of embedding vector according to number of classes in the category
63
- emb_dim = [6 * (num_classes)^(1/4)]
64
- reference: Deep & Cross Network for Ad Click Predictions.(ADKDD'17)
65
-
66
- Args:
67
- num_classes: number of classes in the category
68
-
69
- Returns:
70
- the dim of embedding vector
71
- """
72
- return np.floor(6 * np.pow(num_classes, 0.26))
73
-
74
-
75
- def get_loss_func(task_type="classification"):
76
- if task_type == "classification":
77
- return torch.nn.BCELoss()
78
- elif task_type == "regression":
79
- return torch.nn.MSELoss()
80
- else:
81
- raise ValueError("task_type must be classification or regression")
82
-
83
-
84
- def get_metric_func(task_type="classification"):
85
- if task_type == "classification":
86
- return roc_auc_score
87
- elif task_type == "regression":
88
- return mean_squared_error
89
- else:
90
- raise ValueError("task_type must be classification or regression")
91
-
92
-
93
- def create_seq_features(data, max_len=50, drop_short=3, shuffle=True):
94
- """Build a sequence of user's history by time.
95
-
96
- Args:
97
- data (pd.DataFrame): must contain keys: `user_id, item_id, cate_id, time`.
98
- max_len (int): the max length of a user history sequence.
99
- drop_short (int): remove some inactive user who's sequence length < drop_short.
100
- shuffle (bool): shuffle data if true.
101
-
102
- Returns:
103
- train (pd.DataFrame): target item will be each item before last two items.
104
- val (pd.DataFrame): target item is the second to last item of user's history sequence.
105
- test (pd.DataFrame): target item is the last item of user's history sequence.
106
- """
107
- n_users, n_items, n_cates = data["user_id"].max(), data["item_id"].max(), data["cate_id"].max()
108
- # 0 to be used as the symbol for padding
109
- data = data.astype('int32')
110
- data['item_id'] = data['item_id'].apply(lambda x: x + 1)
111
- data['cate_id'] = data['cate_id'].apply(lambda x: x + 1)
112
-
113
- item_cate_map = data[['item_id', 'cate_id']]
114
- item2cate_dict = item_cate_map.set_index(['item_id'])['cate_id'].to_dict()
115
-
116
- data = data.sort_values(['user_id', 'time']).groupby('user_id').agg(click_hist_list=('item_id', list), cate_hist_hist=('cate_id', list)).reset_index()
117
-
118
- # Sliding window to construct negative samples
119
- train_data, val_data, test_data = [], [], []
120
- for item in data.itertuples():
121
- if len(item[2]) < drop_short:
122
- continue
123
- click_hist_list = item[2][:max_len]
124
- cate_hist_list = item[3][:max_len]
125
-
126
- def neg_sample():
127
- neg = click_hist_list[0]
128
- while neg in click_hist_list:
129
- neg = random.randint(1, n_items)
130
- return neg
131
-
132
- neg_list = [neg_sample() for _ in range(len(click_hist_list))]
133
- hist_list = []
134
- cate_list = []
135
- for i in range(1, len(click_hist_list)):
136
- hist_list.append(click_hist_list[i - 1])
137
- cate_list.append(cate_hist_list[i - 1])
138
- hist_list_pad = hist_list + [0] * (max_len - len(hist_list))
139
- cate_list_pad = cate_list + [0] * (max_len - len(cate_list))
140
- if i == len(click_hist_list) - 1:
141
- test_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
142
- test_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
143
- if i == len(click_hist_list) - 2:
144
- val_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
145
- val_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
146
- else:
147
- train_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
148
- train_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
149
-
150
- # shuffle
151
- if shuffle:
152
- random.shuffle(train_data)
153
- random.shuffle(val_data)
154
- random.shuffle(test_data)
155
-
156
- col_name = ['history_item', 'history_cate', 'target_item', 'target_cate', 'label']
157
- train = pd.DataFrame(train_data, columns=col_name)
158
- val = pd.DataFrame(val_data, columns=col_name)
159
- test = pd.DataFrame(test_data, columns=col_name)
160
-
161
- return train, val, test
162
-
163
-
164
- def df_to_input_dict(data):
165
- data_dict = data.to_dict('list')
166
- for key in data.keys():
167
- data_dict[key] = np.array(data_dict[key])
168
- return data_dict