torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +3 -1
- torch_rechub/basic/callback.py +2 -2
- torch_rechub/basic/features.py +38 -8
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +800 -46
- torch_rechub/basic/loss_func.py +223 -0
- torch_rechub/basic/metaoptimizer.py +76 -0
- torch_rechub/basic/metric.py +251 -0
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -0
- torch_rechub/models/matching/comirec.py +193 -0
- torch_rechub/models/matching/dssm.py +72 -0
- torch_rechub/models/matching/dssm_facebook.py +77 -0
- torch_rechub/models/matching/dssm_senet.py +87 -0
- torch_rechub/models/matching/gru4rec.py +85 -0
- torch_rechub/models/matching/mind.py +103 -0
- torch_rechub/models/matching/narm.py +82 -0
- torch_rechub/models/matching/sasrec.py +143 -0
- torch_rechub/models/matching/sine.py +148 -0
- torch_rechub/models/matching/stamp.py +81 -0
- torch_rechub/models/matching/youtube_dnn.py +75 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -2
- torch_rechub/models/multi_task/aitm.py +83 -0
- torch_rechub/models/multi_task/esmm.py +19 -8
- torch_rechub/models/multi_task/mmoe.py +18 -12
- torch_rechub/models/multi_task/ple.py +41 -29
- torch_rechub/models/multi_task/shared_bottom.py +3 -2
- torch_rechub/models/ranking/__init__.py +13 -2
- torch_rechub/models/ranking/afm.py +65 -0
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +59 -0
- torch_rechub/models/ranking/deepffm.py +131 -0
- torch_rechub/models/ranking/deepfm.py +8 -7
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +31 -19
- torch_rechub/models/ranking/edcn.py +101 -0
- torch_rechub/models/ranking/fibinet.py +42 -0
- torch_rechub/models/ranking/widedeep.py +6 -6
- torch_rechub/trainers/__init__.py +4 -2
- torch_rechub/trainers/ctr_trainer.py +191 -0
- torch_rechub/trainers/match_trainer.py +239 -0
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +137 -23
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +492 -0
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -0
- torch_rechub/utils/mtl.py +136 -0
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub/trainers/trainer.py +0 -111
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: torch-rechub
|
|
3
|
+
Version: 0.0.4
|
|
4
|
+
Summary: A Pytorch Toolbox for Recommendation Models, Easy-to-use and Easy-to-extend.
|
|
5
|
+
Project-URL: Homepage, https://github.com/datawhalechina/torch-rechub
|
|
6
|
+
Project-URL: Documentation, https://www.torch-rechub.com
|
|
7
|
+
Project-URL: Repository, https://github.com/datawhalechina/torch-rechub.git
|
|
8
|
+
Project-URL: Issues, https://github.com/datawhalechina/torch-rechub/issues
|
|
9
|
+
Author-email: rechub team <morningsky@tju.edu.cn>
|
|
10
|
+
License: MIT
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Operating System :: OS Independent
|
|
17
|
+
Classifier: Programming Language :: Python :: 3
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
22
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
23
|
+
Requires-Python: >=3.9
|
|
24
|
+
Requires-Dist: accelerate>=1.0.1
|
|
25
|
+
Requires-Dist: numpy>=1.19.0
|
|
26
|
+
Requires-Dist: pandas>=1.2.0
|
|
27
|
+
Requires-Dist: scikit-learn>=0.24.0
|
|
28
|
+
Requires-Dist: torch>=1.10.0
|
|
29
|
+
Requires-Dist: tqdm>=4.60.0
|
|
30
|
+
Requires-Dist: transformers>=4.46.3
|
|
31
|
+
Provides-Extra: dev
|
|
32
|
+
Requires-Dist: bandit>=1.7.0; extra == 'dev'
|
|
33
|
+
Requires-Dist: flake8>=3.8.0; extra == 'dev'
|
|
34
|
+
Requires-Dist: isort==5.13.2; extra == 'dev'
|
|
35
|
+
Requires-Dist: mypy>=0.800; extra == 'dev'
|
|
36
|
+
Requires-Dist: pre-commit>=2.20.0; extra == 'dev'
|
|
37
|
+
Requires-Dist: pytest-cov>=2.0; extra == 'dev'
|
|
38
|
+
Requires-Dist: pytest>=6.0; extra == 'dev'
|
|
39
|
+
Requires-Dist: toml>=0.10.2; extra == 'dev'
|
|
40
|
+
Requires-Dist: yapf==0.43.0; extra == 'dev'
|
|
41
|
+
Provides-Extra: onnx
|
|
42
|
+
Requires-Dist: onnx>=1.12.0; extra == 'onnx'
|
|
43
|
+
Requires-Dist: onnxruntime>=1.12.0; extra == 'onnx'
|
|
44
|
+
Description-Content-Type: text/markdown
|
|
45
|
+
|
|
46
|
+
# 🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架
|
|
47
|
+
|
|
48
|
+
> 🚀 **30+ 主流推荐模型** | 🎯 **开箱即用** | 📦 **一键部署 ONNX** | 🤖 **支持生成式推荐 (HSTU/HLLM)**
|
|
49
|
+
|
|
50
|
+
[](LICENSE)
|
|
51
|
+

|
|
52
|
+

|
|
53
|
+

|
|
54
|
+
[](https://www.python.org/)
|
|
55
|
+
[](https://pytorch.org/)
|
|
56
|
+
[](https://github.com/spotify/annoy)
|
|
57
|
+
[](https://pandas.pydata.org/)
|
|
58
|
+
[](https://numpy.org/)
|
|
59
|
+
[](https://scikit-learn.org/)
|
|
60
|
+
[](https://pypi.org/project/torch-rechub/)
|
|
61
|
+
|
|
62
|
+
[English](README_en.md) | 简体中文
|
|
63
|
+
|
|
64
|
+
**在线文档:** https://datawhalechina.github.io/torch-rechub/ (英文)| https://datawhalechina.github.io/torch-rechub/zh/ (简体中文)
|
|
65
|
+
|
|
66
|
+
**Torch-RecHub** —— **10 行代码实现工业级推荐系统**。30+ 主流模型开箱即用,支持一键 ONNX 部署,让你专注于业务而非工程。
|
|
67
|
+
|
|
68
|
+

|
|
69
|
+
|
|
70
|
+
## 🎯 为什么选择 Torch-RecHub?
|
|
71
|
+
|
|
72
|
+
| 特性 | Torch-RecHub | 其他框架 |
|
|
73
|
+
|------|-------------|---------|
|
|
74
|
+
| 代码行数 | **10行** 完成训练+评估+部署 | 100+ 行 |
|
|
75
|
+
| 模型覆盖 | **30+** 主流模型 | 有限 |
|
|
76
|
+
| 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
|
|
77
|
+
| ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
|
|
78
|
+
| 学习曲线 | 极低 | 陡峭 |
|
|
79
|
+
|
|
80
|
+
## ✨ 特性
|
|
81
|
+
|
|
82
|
+
* **模块化设计:** 易于添加新的模型、数据集和评估指标。
|
|
83
|
+
* **基于 PyTorch:** 利用 PyTorch 的动态图和 GPU 加速能力。
|
|
84
|
+
* **丰富的模型库:** 涵盖 **30+** 经典和前沿推荐算法(召回、排序、多任务、生成式推荐等)。
|
|
85
|
+
* **标准化流程:** 提供统一的数据加载、训练和评估流程。
|
|
86
|
+
* **易于配置:** 通过配置文件或命令行参数轻松调整实验设置。
|
|
87
|
+
* **可复现性:** 旨在确保实验结果的可复现性。
|
|
88
|
+
* **ONNX 导出:** 支持将训练好的模型导出为 ONNX 格式,便于部署到生产环境。
|
|
89
|
+
* **其他特性:** 例如,支持负采样、多任务学习等。
|
|
90
|
+
|
|
91
|
+
## 📖 目录
|
|
92
|
+
|
|
93
|
+
- [🔥 Torch-RecHub - 轻量、高效、易用的 PyTorch 推荐系统框架](#-torch-rechub---轻量高效易用的-pytorch-推荐系统框架)
|
|
94
|
+
- [🎯 为什么选择 Torch-RecHub?](#-为什么选择-torch-rechub)
|
|
95
|
+
- [✨ 特性](#-特性)
|
|
96
|
+
- [📖 目录](#-目录)
|
|
97
|
+
- [🔧 安装](#-安装)
|
|
98
|
+
- [环境要求](#环境要求)
|
|
99
|
+
- [安装步骤](#安装步骤)
|
|
100
|
+
- [🚀 快速开始](#-快速开始)
|
|
101
|
+
- [📂 项目结构](#-项目结构)
|
|
102
|
+
- [💡 支持的模型](#-支持的模型)
|
|
103
|
+
- [📊 支持的数据集](#-支持的数据集)
|
|
104
|
+
- [🧪 示例](#-示例)
|
|
105
|
+
- [精排(CTR预测)](#精排ctr预测)
|
|
106
|
+
- [多任务排序](#多任务排序)
|
|
107
|
+
- [召回模型](#召回模型)
|
|
108
|
+
- [👨💻 贡献者](#-贡献者)
|
|
109
|
+
- [🤝 贡献指南](#-贡献指南)
|
|
110
|
+
- [📜 许可证](#-许可证)
|
|
111
|
+
- [📚 引用](#-引用)
|
|
112
|
+
- [📫 联系方式](#-联系方式)
|
|
113
|
+
- [⭐️ 项目 star 历史](#️-项目-star-历史)
|
|
114
|
+
|
|
115
|
+
## 🔧 安装
|
|
116
|
+
|
|
117
|
+
### 环境要求
|
|
118
|
+
|
|
119
|
+
* Python 3.9+
|
|
120
|
+
* PyTorch 1.10+ (建议使用支持 CUDA 的版本以获得 GPU 加速)
|
|
121
|
+
* NumPy
|
|
122
|
+
* Pandas
|
|
123
|
+
* SciPy
|
|
124
|
+
* Scikit-learn
|
|
125
|
+
|
|
126
|
+
### 安装步骤
|
|
127
|
+
|
|
128
|
+
**稳定版(推荐用户使用):**
|
|
129
|
+
```bash
|
|
130
|
+
pip install torch-rechub
|
|
131
|
+
```
|
|
132
|
+
|
|
133
|
+
**最新版:**
|
|
134
|
+
```bash
|
|
135
|
+
# 首先安装 uv(如果尚未安装)
|
|
136
|
+
pip install uv
|
|
137
|
+
|
|
138
|
+
# 克隆并安装
|
|
139
|
+
git clone https://github.com/datawhalechina/torch-rechub.git
|
|
140
|
+
cd torch-rechub
|
|
141
|
+
uv sync
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
## 🚀 快速开始
|
|
147
|
+
|
|
148
|
+
以下是一个简单的示例,展示如何在 MovieLens 数据集上训练模型(例如 DSSM):
|
|
149
|
+
|
|
150
|
+
```bash
|
|
151
|
+
# 克隆仓库(如果使用最新版)
|
|
152
|
+
git clone https://github.com/datawhalechina/torch-rechub.git
|
|
153
|
+
cd torch-rechub
|
|
154
|
+
uv sync
|
|
155
|
+
|
|
156
|
+
# 运行示例
|
|
157
|
+
python examples/matching/run_ml_dssm.py
|
|
158
|
+
|
|
159
|
+
# 或使用自定义参数:
|
|
160
|
+
python examples/matching/run_ml_dssm.py --model_name dssm --device 'cuda:0' --learning_rate 0.001 --epoch 50 --batch_size 4096 --weight_decay 0.0001 --save_dir 'saved/dssm_ml-100k'
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
训练完成后,模型文件将保存在 `saved/dssm_ml-100k` 目录下(或你配置的其他目录)。
|
|
164
|
+
|
|
165
|
+
## 📂 项目结构
|
|
166
|
+
|
|
167
|
+
```
|
|
168
|
+
torch-rechub/ # 根目录
|
|
169
|
+
├── README.md # 项目文档
|
|
170
|
+
├── pyproject.toml # 项目配置和依赖
|
|
171
|
+
├── torch_rechub/ # 核心代码库
|
|
172
|
+
│ ├── basic/ # 基础组件
|
|
173
|
+
│ │ ├── activation.py # 激活函数
|
|
174
|
+
│ │ ├── features.py # 特征工程
|
|
175
|
+
│ │ ├── layers.py # 神经网络层
|
|
176
|
+
│ │ ├── loss_func.py # 损失函数
|
|
177
|
+
│ │ └── metric.py # 评估指标
|
|
178
|
+
│ ├── models/ # 推荐模型实现
|
|
179
|
+
│ │ ├── matching/ # 召回模型(DSSM/MIND/GRU4Rec等)
|
|
180
|
+
│ │ ├── ranking/ # 排序模型(WideDeep/DeepFM/DIN等)
|
|
181
|
+
│ │ └── multi_task/ # 多任务模型(MMoE/ESMM等)
|
|
182
|
+
│ ├── trainers/ # 训练框架
|
|
183
|
+
│ │ ├── ctr_trainer.py # CTR预测训练器
|
|
184
|
+
│ │ ├── match_trainer.py # 召回模型训练器
|
|
185
|
+
│ │ └── mtl_trainer.py # 多任务学习训练器
|
|
186
|
+
│ └── utils/ # 工具函数
|
|
187
|
+
│ ├── data.py # 数据处理工具
|
|
188
|
+
│ ├── match.py # 召回工具
|
|
189
|
+
│ ├── mtl.py # 多任务工具
|
|
190
|
+
│ └── onnx_export.py # ONNX 导出工具
|
|
191
|
+
├── examples/ # 示例脚本
|
|
192
|
+
│ ├── matching/ # 召回任务示例
|
|
193
|
+
│ ├── ranking/ # 排序任务示例
|
|
194
|
+
│ └── generative/ # 生成式推荐示例(HSTU、HLLM 等)
|
|
195
|
+
├── docs/ # 文档(VitePress,多语言)
|
|
196
|
+
├── tutorials/ # Jupyter教程
|
|
197
|
+
├── tests/ # 单元测试
|
|
198
|
+
├── config/ # 配置文件
|
|
199
|
+
└── scripts/ # 工具脚本
|
|
200
|
+
```
|
|
201
|
+
|
|
202
|
+
## 💡 支持的模型
|
|
203
|
+
|
|
204
|
+
本框架目前支持 **30+** 主流推荐模型:
|
|
205
|
+
|
|
206
|
+
### 排序模型 (Ranking Models) - 13个
|
|
207
|
+
|
|
208
|
+
| 模型 | 论文 | 简介 |
|
|
209
|
+
|------|------|------|
|
|
210
|
+
| **DeepFM** | [IJCAI 2017](https://arxiv.org/abs/1703.04247) | FM + Deep 联合训练 |
|
|
211
|
+
| **Wide&Deep** | [DLRS 2016](https://arxiv.org/abs/1606.07792) | 记忆 + 泛化能力结合 |
|
|
212
|
+
| **DCN** | [KDD 2017](https://arxiv.org/abs/1708.05123) | 显式特征交叉网络 |
|
|
213
|
+
| **DCN-v2** | [WWW 2021](https://arxiv.org/abs/2008.13535) | 增强版交叉网络 |
|
|
214
|
+
| **DIN** | [KDD 2018](https://arxiv.org/abs/1706.06978) | 注意力机制捕捉用户兴趣 |
|
|
215
|
+
| **DIEN** | [AAAI 2019](https://arxiv.org/abs/1809.03672) | 兴趣演化建模 |
|
|
216
|
+
| **BST** | [DLP-KDD 2019](https://arxiv.org/abs/1905.06874) | Transformer 序列建模 |
|
|
217
|
+
| **AFM** | [IJCAI 2017](https://arxiv.org/abs/1708.04617) | 注意力因子分解机 |
|
|
218
|
+
| **AutoInt** | [CIKM 2019](https://arxiv.org/abs/1810.11921) | 自动特征交互学习 |
|
|
219
|
+
| **FiBiNET** | [RecSys 2019](https://arxiv.org/abs/1905.09433) | 特征重要性 + 双线性交互 |
|
|
220
|
+
| **DeepFFM** | [RecSys 2019](https://arxiv.org/abs/1611.00144) | 场感知因子分解机 |
|
|
221
|
+
| **EDCN** | [KDD 2021](https://arxiv.org/abs/2106.03032) | 增强型交叉网络 |
|
|
222
|
+
|
|
223
|
+
### 召回模型 (Matching Models) - 12个
|
|
224
|
+
|
|
225
|
+
| 模型 | 论文 | 简介 |
|
|
226
|
+
|------|------|------|
|
|
227
|
+
| **DSSM** | [CIKM 2013](https://posenhuang.github.io/papers/cikm2013_DSSM_fullversion.pdf) | 经典双塔召回模型 |
|
|
228
|
+
| **YoutubeDNN** | [RecSys 2016](https://dl.acm.org/doi/10.1145/2959100.2959190) | YouTube 深度召回 |
|
|
229
|
+
| **YoutubeSBC** | [RecSys 2019](https://dl.acm.org/doi/10.1145/3298689.3346997) | 采样偏差校正版本 |
|
|
230
|
+
| **MIND** | [CIKM 2019](https://arxiv.org/abs/1904.08030) | 多兴趣动态路由 |
|
|
231
|
+
| **SINE** | [WSDM 2021](https://arxiv.org/abs/2103.06920) | 稀疏兴趣网络 |
|
|
232
|
+
| **GRU4Rec** | [ICLR 2016](https://arxiv.org/abs/1511.06939) | GRU 序列推荐 |
|
|
233
|
+
| **SASRec** | [ICDM 2018](https://arxiv.org/abs/1808.09781) | 自注意力序列推荐 |
|
|
234
|
+
| **NARM** | [CIKM 2017](https://arxiv.org/abs/1711.04725) | 神经注意力会话推荐 |
|
|
235
|
+
| **STAMP** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3219895) | 短期注意力记忆优先 |
|
|
236
|
+
| **ComiRec** | [KDD 2020](https://arxiv.org/abs/2005.09347) | 可控多兴趣推荐 |
|
|
237
|
+
|
|
238
|
+
### 多任务模型 (Multi-Task Models) - 5个
|
|
239
|
+
|
|
240
|
+
| 模型 | 论文 | 简介 |
|
|
241
|
+
|------|------|------|
|
|
242
|
+
| **ESMM** | [SIGIR 2018](https://arxiv.org/abs/1804.07931) | 全空间多任务建模 |
|
|
243
|
+
| **MMoE** | [KDD 2018](https://dl.acm.org/doi/10.1145/3219819.3220007) | 多门控专家混合 |
|
|
244
|
+
| **PLE** | [RecSys 2020](https://dl.acm.org/doi/10.1145/3383313.3412236) | 渐进式分层提取 |
|
|
245
|
+
| **AITM** | [KDD 2021](https://arxiv.org/abs/2105.08489) | 自适应信息迁移 |
|
|
246
|
+
| **SharedBottom** | - | 经典多任务共享底层 |
|
|
247
|
+
|
|
248
|
+
### 生成式推荐 (Generative Recommendation) - 2个
|
|
249
|
+
|
|
250
|
+
| 模型 | 论文 | 简介 |
|
|
251
|
+
|------|------|------|
|
|
252
|
+
| **HSTU** | [Meta 2024](https://arxiv.org/abs/2402.17152) | 层级序列转换单元,支撑 Meta 万亿参数推荐系统 |
|
|
253
|
+
| **HLLM** | [2024](https://arxiv.org/abs/2409.12740) | 层级大语言模型推荐,融合 LLM 语义理解能力 |
|
|
254
|
+
|
|
255
|
+
## 📊 支持的数据集
|
|
256
|
+
|
|
257
|
+
框架内置了对以下常见数据集格式的支持或提供了处理脚本:
|
|
258
|
+
|
|
259
|
+
* **MovieLens**
|
|
260
|
+
* **Amazon**
|
|
261
|
+
* **Criteo**
|
|
262
|
+
* **Avazu**
|
|
263
|
+
* **Census-Income**
|
|
264
|
+
* **BookCrossing**
|
|
265
|
+
* **Ali-ccp**
|
|
266
|
+
* **Yidian**
|
|
267
|
+
* ...
|
|
268
|
+
|
|
269
|
+
我们期望的数据格式通常是包含以下字段的交互文件:
|
|
270
|
+
- 用户 ID
|
|
271
|
+
- 物品 ID
|
|
272
|
+
- 评分(可选)
|
|
273
|
+
- 时间戳(可选)
|
|
274
|
+
|
|
275
|
+
具体格式要求请参考 `tutorials` 目录下的示例代码。
|
|
276
|
+
|
|
277
|
+
你可以方便地集成你自己的数据集,只需确保它符合框架要求的数据格式,或编写自定义的数据加载器。
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
## 🧪 示例
|
|
281
|
+
|
|
282
|
+
所有模型使用案例参考 `/examples`
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
### 精排(CTR预测)
|
|
286
|
+
|
|
287
|
+
```python
|
|
288
|
+
from torch_rechub.models.ranking import DeepFM
|
|
289
|
+
from torch_rechub.trainers import CTRTrainer
|
|
290
|
+
from torch_rechub.utils.data import DataGenerator
|
|
291
|
+
|
|
292
|
+
dg = DataGenerator(x, y)
|
|
293
|
+
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)
|
|
294
|
+
|
|
295
|
+
model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
|
|
296
|
+
|
|
297
|
+
ctr_trainer = CTRTrainer(model)
|
|
298
|
+
ctr_trainer.fit(train_dataloader, val_dataloader)
|
|
299
|
+
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
|
|
300
|
+
ctr_trainer.export_onnx("deepfm.onnx")
|
|
301
|
+
```
|
|
302
|
+
|
|
303
|
+
### 多任务排序
|
|
304
|
+
|
|
305
|
+
```python
|
|
306
|
+
from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
|
|
307
|
+
from torch_rechub.trainers import MTLTrainer
|
|
308
|
+
|
|
309
|
+
task_types = ["classification", "classification"]
|
|
310
|
+
model = MMOE(features, task_types, 8, expert_params={"dims": [32,16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])
|
|
311
|
+
|
|
312
|
+
mtl_trainer = MTLTrainer(model)
|
|
313
|
+
mtl_trainer.fit(train_dataloader, val_dataloader)
|
|
314
|
+
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)
|
|
315
|
+
mtl_trainer.export_onnx("mmoe.onnx")
|
|
316
|
+
```
|
|
317
|
+
|
|
318
|
+
### 召回模型
|
|
319
|
+
|
|
320
|
+
```python
|
|
321
|
+
from torch_rechub.models.matching import DSSM
|
|
322
|
+
from torch_rechub.trainers import MatchTrainer
|
|
323
|
+
from torch_rechub.utils.data import MatchDataGenerator
|
|
324
|
+
|
|
325
|
+
dg = MatchDataGenerator(x, y)
|
|
326
|
+
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)
|
|
327
|
+
|
|
328
|
+
model = DSSM(user_features, item_features, temperature=0.02,
|
|
329
|
+
user_params={
|
|
330
|
+
"dims": [256, 128, 64],
|
|
331
|
+
"activation": 'prelu',
|
|
332
|
+
},
|
|
333
|
+
item_params={
|
|
334
|
+
"dims": [256, 128, 64],
|
|
335
|
+
"activation": 'prelu',
|
|
336
|
+
})
|
|
337
|
+
|
|
338
|
+
match_trainer = MatchTrainer(model)
|
|
339
|
+
match_trainer.fit(train_dl)
|
|
340
|
+
match_trainer.export_onnx("dssm.onnx")
|
|
341
|
+
# 双塔模型可分别导出用户塔和物品塔:
|
|
342
|
+
# match_trainer.export_onnx("dssm_user.onnx", tower="user")
|
|
343
|
+
# match_trainer.export_onnx("dssm_item.onnx", tower="item")
|
|
344
|
+
```
|
|
345
|
+
|
|
346
|
+
## 👨💻 贡献者
|
|
347
|
+
|
|
348
|
+
感谢所有的贡献者!
|
|
349
|
+
|
|
350
|
+

|
|
351
|
+
|
|
352
|
+
[](https://github.com/datawhalechina/torch-rechub/graphs/contributors)
|
|
353
|
+
|
|
354
|
+
## 🤝 贡献指南
|
|
355
|
+
|
|
356
|
+
我们欢迎各种形式的贡献!请查看 [CONTRIBUTING.md](CONTRIBUTING.md) 了解详细的贡献指南。
|
|
357
|
+
|
|
358
|
+
我们也欢迎通过 [Issues](https://github.com/datawhalechina/torch-rechub/issues) 报告 Bug 或提出功能建议。
|
|
359
|
+
|
|
360
|
+
## 📜 许可证
|
|
361
|
+
|
|
362
|
+
本项目采用 [MIT 许可证](LICENSE)。
|
|
363
|
+
|
|
364
|
+
## 📚 引用
|
|
365
|
+
|
|
366
|
+
如果你在研究或工作中使用了本框架,请考虑引用:
|
|
367
|
+
|
|
368
|
+
```bibtex
|
|
369
|
+
@misc{torch_rechub,
|
|
370
|
+
title = {Torch-RecHub},
|
|
371
|
+
author = {Datawhale},
|
|
372
|
+
year = {2022},
|
|
373
|
+
publisher = {GitHub},
|
|
374
|
+
journal = {GitHub repository},
|
|
375
|
+
howpublished = {\url{https://github.com/datawhalechina/torch-rechub}},
|
|
376
|
+
note = {A PyTorch-based recommender system framework providing easy-to-use and extensible solutions}
|
|
377
|
+
}
|
|
378
|
+
```
|
|
379
|
+
|
|
380
|
+
## 📫 联系方式
|
|
381
|
+
|
|
382
|
+
* **项目负责人:** [1985312383](https://github.com/1985312383)
|
|
383
|
+
* [**GitHub Discussions**](https://github.com/datawhalechina/torch-rechub/discussions)
|
|
384
|
+
|
|
385
|
+
## ⭐️ 项目 star 历史
|
|
386
|
+
|
|
387
|
+
[](https://www.star-history.com/#datawhalechina/torch-rechub&Date)
|
|
388
|
+
|
|
389
|
+
---
|
|
390
|
+
|
|
391
|
+
*最后更新: [2025-12-04]*
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
torch_rechub/__init__.py,sha256=XUwV85oz-uIokuE9qj3nmbUQg3EY8dZcDMohlob3suw,245
|
|
2
|
+
torch_rechub/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
torch_rechub/basic/activation.py,sha256=hIZDCe7cAgV3bX2UnvUrkO8pQs4iXxkQGD0J4GejbVg,1600
|
|
4
|
+
torch_rechub/basic/callback.py,sha256=ZeiDSDQAZUKmyK1AyGJCnqEJ66vwfwlX5lOyu6-h2G0,946
|
|
5
|
+
torch_rechub/basic/features.py,sha256=TLHR5EaNvIbKyKd730Qt8OlLpV0Km91nv2TMnq0HObk,3562
|
|
6
|
+
torch_rechub/basic/initializers.py,sha256=V6hprXvRexcw3vrYsf8Qp-F52fp8uzPMpa1CvkHofy8,3196
|
|
7
|
+
torch_rechub/basic/layers.py,sha256=URWk78dlffMOAhDVDhOhugcr4nmwEa192AI1diktC-4,39653
|
|
8
|
+
torch_rechub/basic/loss_func.py,sha256=6bjljqpiuUP6O8-wUbGd8FSvflY5Dp_DV_57OuQVMz4,7969
|
|
9
|
+
torch_rechub/basic/metaoptimizer.py,sha256=y-oT4MV3vXnSQ5Zd_ZEHP1KClITEi3kbZa6RKjlkYw8,3093
|
|
10
|
+
torch_rechub/basic/metric.py,sha256=9JsaJJGvT6VRvsLoM2Y171CZxESsjYTofD3qnMI-bPM,8443
|
|
11
|
+
torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
torch_rechub/models/generative/__init__.py,sha256=TsCdVIhOcalQwqKZKjEuNbHKyIjyclapKGNwYfFR7TM,135
|
|
13
|
+
torch_rechub/models/generative/hllm.py,sha256=6Vrp5Bh0fTFHCn7C-3EqzOyc7UunOyEY9TzAKGHrW-8,9669
|
|
14
|
+
torch_rechub/models/generative/hstu.py,sha256=M1ByAWHxrkvmwaXPNdGFrbAQYyYJswnABTY3jCEceyg,7846
|
|
15
|
+
torch_rechub/models/matching/__init__.py,sha256=fjWOzJB8loPGy8rJMG-6G-NUISp7k3sD_1FdsKGw1as,471
|
|
16
|
+
torch_rechub/models/matching/comirec.py,sha256=8KB5rg2FWlZaG73CBI7_J8-J-XpjTPnblwh5OsbtAbc,9439
|
|
17
|
+
torch_rechub/models/matching/dssm.py,sha256=1Q0JYpt1h_7NWlLN5a_RbCoUSubZwpYTVEXccSn44eg,3003
|
|
18
|
+
torch_rechub/models/matching/dssm_facebook.py,sha256=n3MS7FT_kyJSDnVTlPCv_nPJ0MHCtMgJRUPDRh7jBLM,3508
|
|
19
|
+
torch_rechub/models/matching/dssm_senet.py,sha256=_E-xEh44XvOaBHP8XdSRkFsTvajhovxlYyCt3H9P61c,4052
|
|
20
|
+
torch_rechub/models/matching/gru4rec.py,sha256=cJtYCkFyg3cPYkOy_YeXRAsTev0cBPiicrj68xJup9k,3932
|
|
21
|
+
torch_rechub/models/matching/mind.py,sha256=NIUeqWhrnZeiFDMNFvXfMx1GMBMaCZnc6nxNZCJpwSE,4933
|
|
22
|
+
torch_rechub/models/matching/narm.py,sha256=2dlTuan9AFrku53WJlBbTwgLlfOHsas3-JBFGxEz7oE,3167
|
|
23
|
+
torch_rechub/models/matching/sasrec.py,sha256=QDfKrFl-aduWg6rY3R13RrdpMiApVugDmtEsWJulgzg,5534
|
|
24
|
+
torch_rechub/models/matching/sine.py,sha256=sUTUHbnewdSBd51epDIp9j-B1guKkhm6eM-KkZ3oS3Q,6746
|
|
25
|
+
torch_rechub/models/matching/stamp.py,sha256=DBVM3iCoQTBKwO7oKHg5SCCDXqTuRJ4Ko1n7StgEovA,3308
|
|
26
|
+
torch_rechub/models/matching/youtube_dnn.py,sha256=EQV_GoEs2Hxwg1U3Dj7-lWkEejEqGmtZ7D9CgfknQdA,3368
|
|
27
|
+
torch_rechub/models/matching/youtube_sbc.py,sha256=paw9uRnbNw_-EaFpRogy7rB4vhw4KN0Qf8BfQylTj4I,4757
|
|
28
|
+
torch_rechub/models/multi_task/__init__.py,sha256=5N8aJ32fzxniDm4d-AeNSi81CFWyBhjoSaK3OC-XCkY,189
|
|
29
|
+
torch_rechub/models/multi_task/aitm.py,sha256=hlG4opauSmM4SNZBPqogPZKNCPheGCVx--JSEWXeIJ8,3355
|
|
30
|
+
torch_rechub/models/multi_task/esmm.py,sha256=y6Gv_mWRw7lcZm-wjw2OHJVwvDHFfvkVSnJvQGO6kUk,2742
|
|
31
|
+
torch_rechub/models/multi_task/mmoe.py,sha256=yprs9P5vL6C4mf3lHb0uyjyJisYVny5UZ6Q9MGVfn-0,2831
|
|
32
|
+
torch_rechub/models/multi_task/ple.py,sha256=5QieTfL-jhHMzuSOlmgXc0dGARJzpLT2sHHdMedjtpM,6448
|
|
33
|
+
torch_rechub/models/multi_task/shared_bottom.py,sha256=XH32YOGsb8M97ggcm6RPY2k3T5g9GgeZsJ2NjwpZoko,1866
|
|
34
|
+
torch_rechub/models/ranking/__init__.py,sha256=VWhbC4918N1hDiUIj7shHp-N7oWOYYpLc4siBtMUV0w,447
|
|
35
|
+
torch_rechub/models/ranking/afm.py,sha256=VbTa3HSyYmpTMdSOr54ex6HOn_uA9tYtmlwX7IKDhUs,2267
|
|
36
|
+
torch_rechub/models/ranking/autoint.py,sha256=VQ8BrPYaSXZzDcSvQ0spcxl5rSqNL-gqemH7UKiLUDk,3880
|
|
37
|
+
torch_rechub/models/ranking/bst.py,sha256=b9cZUnFV51Fw4UAOuPJpog9pHMb2cuGr-yNAvqK7Xpg,3632
|
|
38
|
+
torch_rechub/models/ranking/dcn.py,sha256=ErWAF9MlKsZj_aF1DQQa1cHkobMyo7J5CTreUptRzhw,1302
|
|
39
|
+
torch_rechub/models/ranking/dcn_v2.py,sha256=09mx4LP8D1CVCpFZpJmjCWUm6OzJ9lXgyAvSyaaxzQo,2832
|
|
40
|
+
torch_rechub/models/ranking/deepffm.py,sha256=XYYnpwsJ028fyNcH1MWGPWQVrlonO5aiZO7bpC25PnE,6044
|
|
41
|
+
torch_rechub/models/ranking/deepfm.py,sha256=5yKLrdLPuD3NigGL3bKnG5HS3kqCuRw7gAIkXl6qY9Q,1783
|
|
42
|
+
torch_rechub/models/ranking/dien.py,sha256=2jaPluJf0K_ctHM3MTJ042bXxGLI8KphOTkKVScKUAg,8817
|
|
43
|
+
torch_rechub/models/ranking/din.py,sha256=HsOCEErea3KwEiyWw4M_aX_LMC_-Sqs1C_zeRLKLV_c,4542
|
|
44
|
+
torch_rechub/models/ranking/edcn.py,sha256=6f_S8I6Ir16kCIU54R4EfumWfUFOND5KDKUPHMgsVU0,4997
|
|
45
|
+
torch_rechub/models/ranking/fibinet.py,sha256=fmEJ9WkO8Mn0RtK_8aRHlnQFh_jMBPO0zODoHZPWmDA,2234
|
|
46
|
+
torch_rechub/models/ranking/widedeep.py,sha256=eciRvWRBHLlctabLLS5NB7k3MnqrWXCBdpflOU6jMB0,1636
|
|
47
|
+
torch_rechub/trainers/__init__.py,sha256=NSa2DqgfE1HGDyj40YgrbtUrfBHBxNBpw57XtaAB_jE,148
|
|
48
|
+
torch_rechub/trainers/ctr_trainer.py,sha256=RDUXkn7GwLzs3f0kWZwGDNCpqiMeGXo7R6ezFeZdPg8,9075
|
|
49
|
+
torch_rechub/trainers/match_trainer.py,sha256=xox5eaPKjSgErJQpbSr29sbyGs1p2sFaKEjxACE6uMI,11276
|
|
50
|
+
torch_rechub/trainers/matching.md,sha256=vIBQ3UMmVpUpyk38rrkelFwm_wXVXqMOuqzYZ4M8bzw,30
|
|
51
|
+
torch_rechub/trainers/mtl_trainer.py,sha256=tC4c2KIc-H8Wvj4qCzcW6TyfMLRPJyfQvTaN0dDePFg,12598
|
|
52
|
+
torch_rechub/trainers/seq_trainer.py,sha256=lXKRx7XbZ3iJuqp_f05vw_jkn8X5j8HmH6Nr-typiIU,12043
|
|
53
|
+
torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
54
|
+
torch_rechub/utils/data.py,sha256=vzLAAVt6dujg_vbGhQewiJc0l6JzwzdcM_9EjoOz898,19882
|
|
55
|
+
torch_rechub/utils/hstu_utils.py,sha256=qLON_pJDC-kDyQn1PoN_HaHi5xTNCwZPgJeV51Z61Lc,6207
|
|
56
|
+
torch_rechub/utils/match.py,sha256=l9qDwJGHPP9gOQTMYoqGVdWrlhDx1F1-8UnQwDWrEyk,18143
|
|
57
|
+
torch_rechub/utils/mtl.py,sha256=AxU05ezizCuLdbPuCg1ZXE0WAStzuxaS5Sc3nwMCBpI,5737
|
|
58
|
+
torch_rechub/utils/onnx_export.py,sha256=uRcAD4uZ3eIQbM-DPhdc0bkaPaslNsOYny6BOeLVBfU,13660
|
|
59
|
+
torch_rechub-0.0.4.dist-info/METADATA,sha256=SNm71v_YOfculnc13p266bD_8yLo0U_16F_aJQPDvYo,16149
|
|
60
|
+
torch_rechub-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
61
|
+
torch_rechub-0.0.4.dist-info/licenses/LICENSE,sha256=V7ietiX9G_84HtgEbxDgxClniqXGm2t5q8WM4AHGTu0,1066
|
|
62
|
+
torch_rechub-0.0.4.dist-info/RECORD,,
|
torch_rechub/basic/utils.py
DELETED
|
@@ -1,168 +0,0 @@
|
|
|
1
|
-
import random
|
|
2
|
-
import torch
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pandas as pd
|
|
5
|
-
from torch.utils.data import Dataset, DataLoader, random_split
|
|
6
|
-
from sklearn.metrics import roc_auc_score, mean_squared_error
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class TorchDataset(Dataset):
    """Map-style dataset wrapping a feature dict and an aligned label sequence.

    Args:
        x (dict): mapping of feature name -> indexable array of feature values.
        y: indexable sequence of labels, aligned with every array in ``x``.
    """

    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def __getitem__(self, index):
        # Take the index-th entry from every feature column plus the label.
        sample = {name: column[index] for name, column in self.x.items()}
        return sample, self.y[index]

    def __len__(self):
        # The label sequence defines the dataset length.
        return len(self.y)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class DataGenerator(object):
    """Build train/val/test DataLoaders from a raw feature dict ``x`` and labels ``y``.

    Either pass ``split_ratio`` to randomly split the wrapped dataset, or pass
    explicit hold-out arrays (``x_val``/``y_val``/``x_test``/``y_test``).
    """

    def __init__(self, x, y):
        super(DataGenerator, self).__init__()
        self.dataset = TorchDataset(x, y)
        self.length = len(self.dataset)

    def generate_dataloader(self, x_val=None, y_val=None, x_test=None, y_test=None, split_ratio=None, batch_size=16, num_workers=8):
        """Return ``(train_dataloader, val_dataloader, test_dataloader)``.

        Args:
            x_val, y_val, x_test, y_test: explicit hold-out data, used only
                when ``split_ratio`` is None.
            split_ratio (list): ``[train_fraction, val_fraction]``; the test
                split receives the remaining samples.
            batch_size (int): batch size shared by all three loaders.
            num_workers (int): worker processes per DataLoader.
        """
        if split_ratio is not None:  # fixed: compare to None with `is not`, not `!=`
            train_length = int(self.length * split_ratio[0])
            val_length = int(self.length * split_ratio[1])
            # Remainder goes to the test split so the three parts cover the dataset.
            test_length = self.length - train_length - val_length
            print("the samples of train : val : test are %d : %d : %d" % (train_length, val_length, test_length))
            train_dataset, val_dataset, test_dataset = random_split(self.dataset, (train_length, val_length, test_length))
        else:
            train_dataset = self.dataset
            val_dataset = TorchDataset(x_val, y_val)
            test_dataset = TorchDataset(x_test, y_test)

        # NOTE(review): the train loader is not shuffled here — confirm callers
        # shuffle upstream if shuffled batches are expected.
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
        val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
        return train_dataloader, val_dataloader, test_dataloader
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
class PredictDataset(Dataset):
    """Label-free map-style dataset used at inference time.

    Args:
        x (dict): mapping of feature name -> indexable array of feature values;
            all arrays are assumed to share the same length.
    """

    def __init__(self, x):
        # Fixed: the original called ``super(TorchDataset, self).__init__()``,
        # binding super() to the wrong class.
        super(PredictDataset, self).__init__()
        self.x = x

    def __getitem__(self, index):
        return {k: v[index] for k, v in self.x.items()}

    def __len__(self):
        # All feature arrays share one length; read it from the first column.
        return len(self.x[list(self.x.keys())[0]])
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def get_auto_embedding_dim(num_classes):
    """Calculate the dim of embedding vector according to number of classes in the category.

    Uses the heuristic ``emb_dim = floor(6 * num_classes ** 0.26)``.
    reference: Deep & Cross Network for Ad Click Predictions.(ADKDD'17)

    Args:
        num_classes: number of classes in the category

    Returns:
        int: the dim of embedding vector
    """
    # Fixed: ``np.pow`` only exists as an alias since NumPy 2.0 — use
    # ``np.power``. Also cast to int: embedding dims must be integers,
    # not the float that ``np.floor`` returns.
    return int(np.floor(6 * np.power(num_classes, 0.26)))
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def get_loss_func(task_type="classification"):
    """Return the default loss module for the given task type.

    Args:
        task_type (str): "classification" (BCELoss) or "regression" (MSELoss).

    Returns:
        torch.nn.Module: a freshly constructed loss instance.

    Raises:
        ValueError: if ``task_type`` is neither of the two supported values.
    """
    loss_by_task = {
        "classification": torch.nn.BCELoss,
        "regression": torch.nn.MSELoss,
    }
    loss_cls = loss_by_task.get(task_type)
    if loss_cls is None:
        raise ValueError("task_type must be classification or regression")
    return loss_cls()
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def get_metric_func(task_type="classification"):
    """Return the default metric callable for the given task type.

    Args:
        task_type (str): "classification" (ROC-AUC) or "regression" (MSE).

    Returns:
        callable: a scikit-learn metric function ``f(y_true, y_pred)``.

    Raises:
        ValueError: if ``task_type`` is neither of the two supported values.
    """
    metric_by_task = {
        "classification": roc_auc_score,
        "regression": mean_squared_error,
    }
    metric = metric_by_task.get(task_type)
    if metric is None:
        raise ValueError("task_type must be classification or regression")
    return metric
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def create_seq_features(data, max_len=50, drop_short=3, shuffle=True):
    """Build a sequence of user's history by time.

    Args:
        data (pd.DataFrame): must contain keys: `user_id, item_id, cate_id, time`.
        max_len (int): the max length of a user history sequence.
        drop_short (int): remove some inactive user who's sequence length < drop_short.
        shuffle (bool): shuffle data if true.

    Returns:
        train (pd.DataFrame): target item will be each item before last two items.
        val (pd.DataFrame): target item is the second to last item of user's history sequence.
        test (pd.DataFrame): target item is the last item of user's history sequence.
    """
    n_users, n_items, n_cates = data["user_id"].max(), data["item_id"].max(), data["cate_id"].max()
    # 0 to be used as the symbol for padding
    data = data.astype('int32')
    # shift ids by +1 so that id 0 is free to serve as the padding symbol
    data['item_id'] = data['item_id'].apply(lambda x: x + 1)
    data['cate_id'] = data['cate_id'].apply(lambda x: x + 1)

    # item -> category lookup, used to label the category of sampled negatives
    item_cate_map = data[['item_id', 'cate_id']]
    item2cate_dict = item_cate_map.set_index(['item_id'])['cate_id'].to_dict()

    # collapse to one row per user: chronologically ordered click/category histories
    data = data.sort_values(['user_id', 'time']).groupby('user_id').agg(click_hist_list=('item_id', list), cate_hist_hist=('cate_id', list)).reset_index()

    # Sliding window to construct negative samples
    train_data, val_data, test_data = [], [], []
    for item in data.itertuples():
        # item[2] is click_hist_list; skip users with too short a history
        if len(item[2]) < drop_short:
            continue
        click_hist_list = item[2][:max_len]
        cate_hist_list = item[3][:max_len]

        def neg_sample():
            # rejection-sample an item id the user has never clicked
            neg = click_hist_list[0]
            while neg in click_hist_list:
                neg = random.randint(1, n_items)
            return neg

        # one negative per position, paired with the positive target below
        neg_list = [neg_sample() for _ in range(len(click_hist_list))]
        hist_list = []
        cate_list = []
        # grow the history prefix one click at a time; position i is the target
        for i in range(1, len(click_hist_list)):
            hist_list.append(click_hist_list[i - 1])
            cate_list.append(cate_hist_list[i - 1])
            # right-pad prefixes with the 0 symbol up to max_len
            hist_list_pad = hist_list + [0] * (max_len - len(hist_list))
            cate_list_pad = cate_list + [0] * (max_len - len(cate_list))
            # last click -> test, second-to-last -> val, earlier clicks -> train;
            # each positive row (label 1) is paired with one sampled negative (label 0)
            # NOTE(review): when i is the last index, the `else` of the second
            # `if` also appends that row to train — looks unintended; confirm.
            if i == len(click_hist_list) - 1:
                test_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                test_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            if i == len(click_hist_list) - 2:
                val_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                val_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
            else:
                train_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
                train_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])

    # shuffle
    if shuffle:
        random.shuffle(train_data)
        random.shuffle(val_data)
        random.shuffle(test_data)

    col_name = ['history_item', 'history_cate', 'target_item', 'target_cate', 'label']
    train = pd.DataFrame(train_data, columns=col_name)
    val = pd.DataFrame(val_data, columns=col_name)
    test = pd.DataFrame(test_data, columns=col_name)

    return train, val, test
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
def df_to_input_dict(data):
    """Convert a DataFrame into a dict of column name -> numpy array."""
    columns_as_lists = data.to_dict('list')
    return {column: np.array(values) for column, values in columns_as_lists.items()}
|