torch-rechub 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/basic/activation.py +54 -52
- torch_rechub/basic/callback.py +32 -32
- torch_rechub/basic/features.py +94 -57
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +720 -240
- torch_rechub/basic/loss_func.py +34 -0
- torch_rechub/basic/metaoptimizer.py +72 -0
- torch_rechub/basic/metric.py +250 -0
- torch_rechub/models/matching/__init__.py +11 -0
- torch_rechub/models/matching/comirec.py +188 -0
- torch_rechub/models/matching/dssm.py +66 -0
- torch_rechub/models/matching/dssm_facebook.py +79 -0
- torch_rechub/models/matching/dssm_senet.py +75 -0
- torch_rechub/models/matching/gru4rec.py +87 -0
- torch_rechub/models/matching/mind.py +101 -0
- torch_rechub/models/matching/narm.py +76 -0
- torch_rechub/models/matching/sasrec.py +140 -0
- torch_rechub/models/matching/sine.py +151 -0
- torch_rechub/models/matching/stamp.py +83 -0
- torch_rechub/models/matching/youtube_dnn.py +71 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -4
- torch_rechub/models/multi_task/aitm.py +84 -0
- torch_rechub/models/multi_task/esmm.py +55 -45
- torch_rechub/models/multi_task/mmoe.py +58 -52
- torch_rechub/models/multi_task/ple.py +130 -104
- torch_rechub/models/multi_task/shared_bottom.py +45 -44
- torch_rechub/models/ranking/__init__.py +11 -3
- torch_rechub/models/ranking/afm.py +63 -0
- torch_rechub/models/ranking/bst.py +63 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +69 -0
- torch_rechub/models/ranking/deepffm.py +123 -0
- torch_rechub/models/ranking/deepfm.py +41 -41
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +91 -81
- torch_rechub/models/ranking/edcn.py +117 -0
- torch_rechub/models/ranking/fibinet.py +50 -0
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +2 -1
- torch_rechub/trainers/{trainer.py → ctr_trainer.py} +128 -111
- torch_rechub/trainers/match_trainer.py +170 -0
- torch_rechub/trainers/mtl_trainer.py +206 -144
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +360 -0
- torch_rechub/utils/match.py +274 -0
- torch_rechub/utils/mtl.py +126 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +177 -0
- torch_rechub-0.0.3.dist-info/RECORD +55 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/WHEEL +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.optim.optimizer import Optimizer
|
|
3
|
+
from ..models.multi_task import MMOE, SharedBottom, PLE, AITM
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def shared_task_layers(model):
    """Split a multi-task model's parameters into shared and task-specific groups.

    The embedding table is always treated as shared; the remaining modules are
    assigned to one of the two groups depending on the concrete model class.

    Authors: Qida Dong, dongjidan@126.com

    Args:
        model (torch.nn.Module): only support `[MMOE, SharedBottom, PLE, AITM]`

    Returns:
        list[torch.nn.parameter]: parameters split to shared list and task list.

    Raises:
        ValueError: if `model` is not one of the supported multi-task classes.
    """
    shared = list(model.embedding.parameters())
    if isinstance(model, SharedBottom):
        # bottom MLP is consumed by every task -> shared
        shared += list(model.bottom_mlp.parameters())
        task = list(model.towers.parameters()) + list(model.predict_layers.parameters())
    elif isinstance(model, MMOE):
        # experts are shared across tasks; gates are per-task selectors
        shared += list(model.experts.parameters())
        task = (list(model.towers.parameters()) + list(model.predict_layers.parameters()) + list(model.gates.parameters()))
    elif isinstance(model, PLE):
        shared += list(model.cgc_layers.parameters())
        task = list(model.towers.parameters()) + list(model.predict_layers.parameters())
    elif isinstance(model, AITM):
        shared += list(model.bottoms.parameters())
        task = (list(model.info_gates.parameters()) + list(model.towers.parameters()) + list(model.aits.parameters()))
    else:
        raise ValueError(f'this model {model} is not suitable for MetaBalance Optimizer')
    return shared, task
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class MetaBalance(Optimizer):
    """MetaBalance Optimizer
    This method is used to scale the gradient and balance the gradient of each task.
    Authors: Qida Dong, dongjidan@126.com

    Args:
        parameters (list): the parameters of model
        relax_factor (float, optional): the relax factor of gradient scaling (default: 0.7)
        beta (float, optional): the coefficient of moving average (default: 0.9)
    """

    def __init__(self, parameters, relax_factor=0.7, beta=0.9):
        # Validate hyper-parameters before handing them to the base Optimizer;
        # both must lie in the half-open interval [0, 1).
        if relax_factor < 0. or relax_factor >= 1.:
            raise ValueError(f'Invalid relax_factor: {relax_factor}, it should be 0. <= relax_factor < 1.')
        if beta < 0. or beta >= 1.:
            raise ValueError(f'Invalid beta: {beta}, it should be 0. <= beta < 1.')
        # Packed as the Optimizer "defaults" dict, so each param_group carries
        # its own 'relax_factor' and 'beta' entries.
        rel_beta_dict = {'relax_factor': relax_factor, 'beta': beta}
        super(MetaBalance, self).__init__(parameters, rel_beta_dict)

    @torch.no_grad()
    def step(self, losses):
        """Backprop each task loss in turn and write the balanced, summed
        gradient into ``p.grad`` for every managed parameter.

        Unlike a standard ``Optimizer.step``, this does NOT update parameter
        values; it only rewrites gradients. An outer optimizer is expected to
        apply the actual update afterwards.

        Args:
            losses (list[torch.Tensor]): one scalar loss per task; index 0 is
                treated as the anchor task whose gradient magnitude the
                auxiliary tasks are scaled towards.
        """
        for idx, loss in enumerate(losses):
            # One backward pass per task; retain_graph so later losses can
            # still backprop through the shared graph.
            loss.backward(retain_graph=True)
            for group in self.param_groups:
                for gp in group['params']:
                    if gp.grad is None:
                        # NOTE(review): 'break' abandons the REMAINING params of
                        # this group as soon as one has no grad — presumably
                        # grad-less params are ordered last; confirm, otherwise
                        # 'continue' may be the intended behavior.
                        break
                    if gp.grad.is_sparse:
                        raise RuntimeError('MetaBalance does not support sparse gradients')
                    # store the result of moving average
                    state = self.state[gp]
                    # First time we see this param: one moving-average slot per
                    # task, kept on the param object itself (persists across steps).
                    if len(state) == 0:
                        for i in range(len(losses)):
                            if i == 0:
                                gp.norms = [0]
                            else:
                                gp.norms.append(0)
                    # calculate the moving average of this task's gradient norm
                    beta = group['beta']
                    gp.norms[idx] = gp.norms[idx] * beta + (1 - beta) * torch.norm(gp.grad)
                    # scale the auxiliary gradient towards the anchor task's
                    # (index 0) magnitude; relax_factor blends scaled vs raw grad
                    relax_factor = group['relax_factor']
                    gp.grad = gp.grad * gp.norms[0] / (gp.norms[idx] + 1e-5) * relax_factor + gp.grad * (1. - relax_factor)
                    # store the gradient of each auxiliary task in state
                    if idx == 0:
                        # fresh accumulator at the start of every step() call
                        state['sum_gradient'] = torch.zeros_like(gp.data)
                        state['sum_gradient'] += gp.grad
                    else:
                        state['sum_gradient'] += gp.grad

                    if gp.grad is not None:
                        # Clear the per-task grad so the next loss.backward()
                        # does not accumulate on top of the scaled value.
                        gp.grad.detach_()
                        gp.grad.zero_()
                    if idx == len(losses) - 1:
                        # After the last task, expose the balanced sum as the
                        # param's gradient for the outer optimizer to consume.
                        gp.grad = state['sum_gradient']
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def gradnorm(loss_list, loss_weight, share_layer, initial_task_loss, alpha):
    """Compute GradNorm gradients for the task-weight parameters.

    Backprops the weighted total loss, then overwrites each weight's ``.grad``
    with the gradient of the GradNorm balancing loss, so a subsequent
    ``optimizer.step()`` adapts the task weights.

    Args:
        loss_list (list[torch.Tensor]): one scalar loss per task (graph attached).
        loss_weight (list[torch.Tensor]): learnable scalar weight w_i per task.
        share_layer (torch.Tensor): shared parameter tensor used to measure
            per-task gradient magnitudes.
        initial_task_loss (list[float]): loss values recorded at training start,
            used for the inverse training rate r_i(t).
        alpha (float): asymmetry hyper-parameter of GradNorm.
    """
    # Weighted total loss; this backward also fills the weights' grads,
    # which are zeroed right after because GradNorm replaces them.
    total = sum(task_loss * task_w for task_loss, task_w in zip(loss_list, loss_weight))
    total.backward(retain_graph=True)
    # set the gradients of w_i(t) to zero because these gradients have to be updated using the GradNorm loss
    for task_w in loss_weight:
        task_w.grad.data = task_w.grad.data * 0.0

    # G^{(i)}_w(t): weighted gradient norms on the shared layer, plus the
    # inverse training rate r_i(t) for each task.
    grad_norms, inverse_rates = [], []
    for i in range(len(loss_list)):
        # gradient of this task loss w.r.t. the shared parameters
        task_grads = torch.autograd.grad(loss_list[i], share_layer, retain_graph=True)
        grad_norms.append(torch.norm(loss_weight[i] * task_grads[0]))
        inverse_rates.append(loss_list[i].item() / initial_task_loss[i])
    grad_norms = torch.stack(grad_norms)
    avg_norm = grad_norms.detach().mean()
    avg_rate = sum(inverse_rates) / len(inverse_rates)
    # GradNorm target: treated as a constant (detached mean norm x rate^alpha)
    target = avg_norm * (avg_rate**alpha)
    gradnorm_loss = torch.abs(grad_norms - target).sum()

    # write the GradNorm gradient into each weight's .grad slot
    for task_w in loss_weight:
        task_w.grad = torch.autograd.grad(gradnorm_loss, task_w, retain_graph=True)[0]
|
|
@@ -1,21 +1,21 @@
|
|
|
1
|
-
MIT License
|
|
2
|
-
|
|
3
|
-
Copyright (c) 2022
|
|
4
|
-
|
|
5
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
-
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
-
in the Software without restriction, including without limitation the rights
|
|
8
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
-
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
-
furnished to do so, subject to the following conditions:
|
|
11
|
-
|
|
12
|
-
The above copyright notice and this permission notice shall be included in all
|
|
13
|
-
copies or substantial portions of the Software.
|
|
14
|
-
|
|
15
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
-
SOFTWARE.
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2022 Datawhale
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: torch-rechub
|
|
3
|
+
Version: 0.0.3
|
|
4
|
+
Summary: A Lighting Pytorch Framework for Recommendation System, Easy-to-use and Easy-to-extend.
|
|
5
|
+
Home-page: https://github.com/datawhalechina/torch-rechub
|
|
6
|
+
Author: Datawhale
|
|
7
|
+
Author-email: laimc@shanghaitech.edu.cn
|
|
8
|
+
Keywords: ctr,click through rate,deep learning,pytorch,recsys,recommendation
|
|
9
|
+
Platform: all
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Intended Audience :: Education
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Operating System :: OS Independent
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
23
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
24
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
25
|
+
Requires-Python: >=3.8
|
|
26
|
+
Description-Content-Type: text/markdown
|
|
27
|
+
License-File: LICENSE
|
|
28
|
+
Requires-Dist: numpy>=1.19.0
|
|
29
|
+
Requires-Dist: torch>=1.7.0
|
|
30
|
+
Requires-Dist: pandas>=1.0.5
|
|
31
|
+
Requires-Dist: tqdm>=4.64.0
|
|
32
|
+
Requires-Dist: scikit-learn>=0.23.2
|
|
33
|
+
Requires-Dist: annoy>=1.17.0
|
|
34
|
+
|
|
35
|
+
# Torch-RecHub
|
|
36
|
+
|
|
37
|
+
<p align="left">
|
|
38
|
+
<img src='https://img.shields.io/badge/python-3.8+-brightgreen'>
|
|
39
|
+
<img src='https://img.shields.io/badge/torch-1.7+-brightgreen'>
|
|
40
|
+
<img src='https://img.shields.io/badge/scikit_learn-0.23.2+-brightgreen'>
|
|
41
|
+
<img src='https://img.shields.io/badge/pandas-1.0.5+-brightgreen'>
|
|
42
|
+
<img src='https://img.shields.io/badge/annoy-1.17.0-brightgreen'>
|
|
43
|
+
<img src="https://img.shields.io/pypi/l/torch-rechub">
|
|
44
|
+
<a href="https://github.com/datawhalechina/torch-rechub"><img src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2Fdatawhalechina%2Ftorch-rechub&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=hits&edge_flat=false"/></a>
|
|
45
|
+
|
|
46
|
+
## 中文Wiki站
|
|
47
|
+
|
|
48
|
+
查看最新研发进度,认领感兴趣的研发任务,学习rechub模型复现心得,加入rechub共建者团队等
|
|
49
|
+
|
|
50
|
+
[点击链接](https://www.wolai.com/rechub/2qjdg3DPy1179e1vpcHZQC)
|
|
51
|
+
|
|
52
|
+
## 安装
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
#稳定版
|
|
56
|
+
pip install torch-rechub
|
|
57
|
+
|
|
58
|
+
#最新版(推荐)
|
|
59
|
+
1. git clone https://github.com/datawhalechina/torch-rechub.git
|
|
60
|
+
2. cd torch-rechub
|
|
61
|
+
3. python setup.py install
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
## 核心定位
|
|
65
|
+
|
|
66
|
+
易用易拓展,聚焦复现业界实用的推荐模型,以及泛生态化的推荐场景
|
|
67
|
+
|
|
68
|
+
## 主要特性
|
|
69
|
+
|
|
70
|
+
* scikit-learn风格易用的API(fit、predict),即插即用
|
|
71
|
+
|
|
72
|
+
* 模型训练与模型定义解耦,易拓展,可针对不同类型的模型设置不同的训练机制
|
|
73
|
+
|
|
74
|
+
* 接受pandas的DataFrame、Dict数据输入,上手成本低
|
|
75
|
+
|
|
76
|
+
* 高度模块化,支持常见Layer,容易调用组装成新模型
|
|
77
|
+
|
|
78
|
+
* LR、MLP、FM、FFM、CIN
|
|
79
|
+
|
|
80
|
+
* target-attention、self-attention、transformer
|
|
81
|
+
|
|
82
|
+
* 支持常见排序模型
|
|
83
|
+
|
|
84
|
+
* WideDeep、DeepFM、DIN、DCN、xDeepFM等
|
|
85
|
+
|
|
86
|
+
* 支持常见召回模型
|
|
87
|
+
|
|
88
|
+
* DSSM、YoutubeDNN、YoutubeDSSM、FacebookEBR、MIND等
|
|
89
|
+
|
|
90
|
+
* 丰富的多任务学习支持
|
|
91
|
+
|
|
92
|
+
* SharedBottom、ESMM、MMOE、PLE、AITM等模型
|
|
93
|
+
|
|
94
|
+
* GradNorm、UWL、MetaBalance等动态loss加权机制
|
|
95
|
+
|
|
96
|
+
* 聚焦更生态化的推荐场景
|
|
97
|
+
|
|
98
|
+
- [ ] 冷启动
|
|
99
|
+
|
|
100
|
+
- [ ] 延迟反馈
|
|
101
|
+
|
|
102
|
+
* [ ] 去偏
|
|
103
|
+
|
|
104
|
+
* 支持丰富的训练机制
|
|
105
|
+
|
|
106
|
+
* [ ] 对比学习
|
|
107
|
+
|
|
108
|
+
* [ ] 蒸馏学习
|
|
109
|
+
|
|
110
|
+
* [ ] 第三方高性能开源Trainer支持(Pytorch Lighting)
|
|
111
|
+
|
|
112
|
+
* [ ] 更多模型正在开发中
|
|
113
|
+
|
|
114
|
+
## 快速使用
|
|
115
|
+
|
|
116
|
+
### 使用案例
|
|
117
|
+
|
|
118
|
+
- 所有模型使用案例参考 `/examples`
|
|
119
|
+
|
|
120
|
+
- 202206 Datawhale-RecHub推荐课程 组队学习期间notebook教程参考 `/tutorials`
|
|
121
|
+
|
|
122
|
+
### 精排(CTR预测)
|
|
123
|
+
|
|
124
|
+
```python
|
|
125
|
+
from torch_rechub.models.ranking import DeepFM
|
|
126
|
+
from torch_rechub.trainers import CTRTrainer
|
|
127
|
+
from torch_rechub.utils.data import DataGenerator
|
|
128
|
+
|
|
129
|
+
dg = DataGenerator(x, y)
|
|
130
|
+
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)
|
|
131
|
+
|
|
132
|
+
model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
|
|
133
|
+
|
|
134
|
+
ctr_trainer = CTRTrainer(model)
|
|
135
|
+
ctr_trainer.fit(train_dataloader, val_dataloader)
|
|
136
|
+
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
|
|
137
|
+
```
|
|
138
|
+
|
|
139
|
+
### 多任务排序
|
|
140
|
+
|
|
141
|
+
```python
|
|
142
|
+
from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
|
|
143
|
+
from torch_rechub.trainers import MTLTrainer
|
|
144
|
+
|
|
145
|
+
task_types = ["classification", "classification"]
|
|
146
|
+
model = MMOE(features, task_types, 8, expert_params={"dims": [32,16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])
|
|
147
|
+
|
|
148
|
+
mtl_trainer = MTLTrainer(model)
|
|
149
|
+
mtl_trainer.fit(train_dataloader, val_dataloader)
|
|
150
|
+
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)
|
|
151
|
+
```
|
|
152
|
+
|
|
153
|
+
### 召回模型
|
|
154
|
+
|
|
155
|
+
```python
|
|
156
|
+
from torch_rechub.models.matching import DSSM
|
|
157
|
+
from torch_rechub.trainers import MatchTrainer
|
|
158
|
+
from torch_rechub.utils.data import MatchDataGenerator
|
|
159
|
+
|
|
160
|
+
dg = MatchDataGenerator(x, y)
|
|
161
|
+
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)
|
|
162
|
+
|
|
163
|
+
model = DSSM(user_features, item_features, temperature=0.02,
|
|
164
|
+
user_params={
|
|
165
|
+
"dims": [256, 128, 64],
|
|
166
|
+
"activation": 'prelu',
|
|
167
|
+
},
|
|
168
|
+
item_params={
|
|
169
|
+
"dims": [256, 128, 64],
|
|
170
|
+
"activation": 'prelu',
|
|
171
|
+
})
|
|
172
|
+
|
|
173
|
+
match_trainer = MatchTrainer(model)
|
|
174
|
+
match_trainer.fit(train_dl)
|
|
175
|
+
|
|
176
|
+
```
|
|
177
|
+
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
torch_rechub/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
torch_rechub/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
torch_rechub/basic/activation.py,sha256=MjqWIXKhOygkwxVpuCNPaqgmP_gB13tHBgANyFoPh8k,1658
|
|
4
|
+
torch_rechub/basic/callback.py,sha256=GzAOfW-xWOJNdZhxE21crFX6-GbKinYApdPtdz4WwNQ,985
|
|
5
|
+
torch_rechub/basic/features.py,sha256=3SqSQ0uRj6lwr3j7OXIlLeS05jt5v9nlkkgilWfALL0,3777
|
|
6
|
+
torch_rechub/basic/initializers.py,sha256=AfJd0mqQqGqzk12IC_N2qZgRzz4LyP8n9FcmWzWTmH4,3292
|
|
7
|
+
torch_rechub/basic/layers.py,sha256=rOdbvpPxp3GbQ4lgZSZgWeOuhbDgjVOpHe-rbha_fyk,30484
|
|
8
|
+
torch_rechub/basic/loss_func.py,sha256=EeqYUClS3CVv3EEx2XnzK0vWAULZAw_Lr0mjf9XN6U4,1170
|
|
9
|
+
torch_rechub/basic/metaoptimizer.py,sha256=AmG-LrkDKhSnI2kK7qzVVJS0eQ3aMWHS1nEqeRXWFgU,3138
|
|
10
|
+
torch_rechub/basic/metric.py,sha256=m8fKjRVja6PyllV5OrHXwbpNntK73jFv-QzA94pNAIU,7508
|
|
11
|
+
torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
torch_rechub/models/matching/__init__.py,sha256=gR9csAI1ataaIVRqdlBSpsEYmLotq3UNc-YUV9PvAf4,336
|
|
13
|
+
torch_rechub/models/matching/comirec.py,sha256=-i3bb9ovcyEtGl8wqpLVo0EIwADjd3Ui7jn3b3pX7ps,9620
|
|
14
|
+
torch_rechub/models/matching/dssm.py,sha256=FuWhH-lGtqAfe6vm_Pi1K0l-wLkqTVYwtIxFrwISsXo,3042
|
|
15
|
+
torch_rechub/models/matching/dssm_facebook.py,sha256=n8nqHk2RFm_qUuBEqiu5leMSFj6giGibanru0yPk7Rk,3679
|
|
16
|
+
torch_rechub/models/matching/dssm_senet.py,sha256=AhY298J8kHjtubmJkHht8QpXnH5QC1ArPrDBYA2q4KE,3967
|
|
17
|
+
torch_rechub/models/matching/gru4rec.py,sha256=RzbVrTV-LkLB8g7vwpsUXd7viyz6ZPf0BBv3IK1bs_I,4136
|
|
18
|
+
torch_rechub/models/matching/mind.py,sha256=pnlvqYXU5sKYpzIs-ktmMBFTEUgNIfNjoXXgurFbwz4,5029
|
|
19
|
+
torch_rechub/models/matching/narm.py,sha256=UTsIuCSoMPzVDlnB6eOF-vEF5SYV-VET9pMalFYDT4I,3244
|
|
20
|
+
torch_rechub/models/matching/sasrec.py,sha256=9_4vO9Kz25Wj19i0SPsOfhbQWMJKtGQneNRMUQV0w6Y,5865
|
|
21
|
+
torch_rechub/models/matching/sine.py,sha256=uglxLI4r5JU0CnFdKZdSPcnkvmeE6b9vr2N5sBoP_qw,7089
|
|
22
|
+
torch_rechub/models/matching/stamp.py,sha256=8N7Zl527QvqARtrwf8niP86djnMy4YYqT_EPp-7UhTY,3556
|
|
23
|
+
torch_rechub/models/matching/youtube_dnn.py,sha256=ZIQY4nQFRi73QWYVALgFaApaiJQCxLQoeHzk0O2BBxg,3453
|
|
24
|
+
torch_rechub/models/matching/youtube_sbc.py,sha256=L5xfgxEvDTlBmK09GNRXiZVfFAf4v2WyKltRE18CVX8,4945
|
|
25
|
+
torch_rechub/models/multi_task/__init__.py,sha256=1dAe4mo8kRjsCAUx6Cs3jzw-PrlI2y0B5q6BK-THgZM,133
|
|
26
|
+
torch_rechub/models/multi_task/aitm.py,sha256=dcXDBXwYEHNlBlGv-gbY6lZ6Nba3Wi6iqYuYP1SGR40,3454
|
|
27
|
+
torch_rechub/models/multi_task/esmm.py,sha256=WYUiuhLxeBx9QpgJS6RawPvz9n0WWff0q7jEX4xnbDI,2789
|
|
28
|
+
torch_rechub/models/multi_task/mmoe.py,sha256=e3_0tO9dbH831y0kRdpy8hDQhFSpM7gI9GTspYCAZeA,2926
|
|
29
|
+
torch_rechub/models/multi_task/ple.py,sha256=3S5ufr320BOl41hUYxez6tKPcKKQoHq7yW1j7q-tNQk,7005
|
|
30
|
+
torch_rechub/models/multi_task/shared_bottom.py,sha256=EgLIF1Hb6u-S83pLeX-DlXAbZWu92KSS4kTl5CKQllw,1913
|
|
31
|
+
torch_rechub/models/ranking/__init__.py,sha256=Hj9_4RghDB0yzc8EGP9wwBb7VMpgogQN71NCtcOj6MI,293
|
|
32
|
+
torch_rechub/models/ranking/afm.py,sha256=HK1ww6g38vKEhEj9M0UPEVnTQ_OQKggv5qxxkrpjlLA,2323
|
|
33
|
+
torch_rechub/models/ranking/bst.py,sha256=mSISHGSYsNwjdPDOsxyXPuDXsR1qxYrVNUZBx03XDw4,3815
|
|
34
|
+
torch_rechub/models/ranking/dcn.py,sha256=OB0seU6YZRV9nGSSpV3aZWTWSXJlY0f45sfBtleaMV0,1339
|
|
35
|
+
torch_rechub/models/ranking/dcn_v2.py,sha256=yUAncBAEAEjkzPbZFKDWoSQFa4Ogo48qkHVnpprB8JQ,3182
|
|
36
|
+
torch_rechub/models/ranking/deepffm.py,sha256=mrA-tOHhNsqoaO1lgd6mg-nVMNScOt8fW1dKOjYyeT0,6429
|
|
37
|
+
torch_rechub/models/ranking/deepfm.py,sha256=g_RbTY0oWz6Bu1XyU6B0BXLznXKroIVjFrOzw5InVXU,1815
|
|
38
|
+
torch_rechub/models/ranking/dien.py,sha256=r0Wjm3MINAI3q0E82mQIGN1Q8Y6B4ZA9-uax6xfvjEg,9164
|
|
39
|
+
torch_rechub/models/ranking/din.py,sha256=mSmohL7DHlqppo8aSNFYKA8FmpBww8Cxjn-U476Ag7I,4715
|
|
40
|
+
torch_rechub/models/ranking/edcn.py,sha256=XJQ9vSYRgQMyJlGFAJH95OfIl5UAzvinmvtBYD-sHCU,5673
|
|
41
|
+
torch_rechub/models/ranking/fibinet.py,sha256=dLfDGdBbhlgz5JmgzBCKRrvYHgkAjP3VfcMsa5Sv3Tg,2294
|
|
42
|
+
torch_rechub/models/ranking/widedeep.py,sha256=tSFqeTsg-vjz-6_Av18GZqRVh3NBs7HlumE8-uV9-88,1675
|
|
43
|
+
torch_rechub/trainers/__init__.py,sha256=TPdNJp7UjHrmJE5Jq3YfM0OZ71Kfd6tDmBS-HhTdtxU,113
|
|
44
|
+
torch_rechub/trainers/ctr_trainer.py,sha256=Hd87UEV4cOOMTvXJnEp3ZXFXAqryTpjgp0TSSFpwcM8,6026
|
|
45
|
+
torch_rechub/trainers/match_trainer.py,sha256=Q8OiwCQKm8-aBrCy7qbgII6zOk38qZaJetiis2oPr-0,7988
|
|
46
|
+
torch_rechub/trainers/mtl_trainer.py,sha256=tOKO7iQ0CJ4M_AAwzjnH5m6dizWrhdED41oRub6ueBY,10006
|
|
47
|
+
torch_rechub/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
48
|
+
torch_rechub/utils/data.py,sha256=vuRigwbrpXtEBn47zwzqWTtSC4FXihOTk-4hPEUp0OU,14787
|
|
49
|
+
torch_rechub/utils/match.py,sha256=diXjN3mDaRSDc24oJpZxWPaXL6pkR6iAaiurfzmJdG8,11908
|
|
50
|
+
torch_rechub/utils/mtl.py,sha256=JpRobRgNyCFtDkVVhbQkLdMT4pFrPihJ1cm4QJogddE,5911
|
|
51
|
+
torch_rechub-0.0.3.dist-info/LICENSE,sha256=B_RpxT4MtHbioKj8jkv1yhBWs4Zlcd8I4vMS8GH5Gwg,1087
|
|
52
|
+
torch_rechub-0.0.3.dist-info/METADATA,sha256=i8JZFLp3YQVZY-awer1MLWwSCJAmJAgJof-ylPptGpI,5833
|
|
53
|
+
torch_rechub-0.0.3.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
|
|
54
|
+
torch_rechub-0.0.3.dist-info/top_level.txt,sha256=_RPvAnlLHdll9u2d74aFr_oOAj_NxgfLpH02Uifz_YY,13
|
|
55
|
+
torch_rechub-0.0.3.dist-info/RECORD,,
|
torch_rechub/basic/utils.py
DELETED
|
@@ -1,168 +0,0 @@
|
|
|
1
|
-
import random
|
|
2
|
-
import torch
|
|
3
|
-
import numpy as np
|
|
4
|
-
import pandas as pd
|
|
5
|
-
from torch.utils.data import Dataset, DataLoader, random_split
|
|
6
|
-
from sklearn.metrics import roc_auc_score, mean_squared_error
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class TorchDataset(Dataset):
|
|
10
|
-
|
|
11
|
-
def __init__(self, x, y):
|
|
12
|
-
super(TorchDataset, self).__init__()
|
|
13
|
-
self.x = x
|
|
14
|
-
self.y = y
|
|
15
|
-
|
|
16
|
-
def __getitem__(self, index):
|
|
17
|
-
return {k: v[index] for k, v in self.x.items()}, self.y[index]
|
|
18
|
-
|
|
19
|
-
def __len__(self):
|
|
20
|
-
return len(self.y)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class DataGenerator(object):
|
|
24
|
-
|
|
25
|
-
def __init__(self, x, y):
|
|
26
|
-
super(DataGenerator, self).__init__()
|
|
27
|
-
self.dataset = TorchDataset(x, y)
|
|
28
|
-
self.length = len(self.dataset)
|
|
29
|
-
|
|
30
|
-
def generate_dataloader(self, x_val=None, y_val=None, x_test=None, y_test=None, split_ratio=None, batch_size=16, num_workers=8):
|
|
31
|
-
if split_ratio != None:
|
|
32
|
-
train_length = int(self.length * split_ratio[0])
|
|
33
|
-
val_length = int(self.length * split_ratio[1])
|
|
34
|
-
test_length = self.length - train_length - val_length
|
|
35
|
-
print("the samples of train : val : test are %d : %d : %d" % (train_length, val_length, test_length))
|
|
36
|
-
train_dataset, val_dataset, test_dataset = random_split(self.dataset, (train_length, val_length, test_length))
|
|
37
|
-
else:
|
|
38
|
-
train_dataset = self.dataset
|
|
39
|
-
val_dataset = TorchDataset(x_val, y_val)
|
|
40
|
-
test_dataset = TorchDataset(x_test, y_test)
|
|
41
|
-
|
|
42
|
-
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
|
|
43
|
-
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)
|
|
44
|
-
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
|
|
45
|
-
return train_dataloader, val_dataloader, test_dataloader
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
class PredictDataset(Dataset):
|
|
49
|
-
|
|
50
|
-
def __init__(self, x):
|
|
51
|
-
super(TorchDataset, self).__init__()
|
|
52
|
-
self.x = x
|
|
53
|
-
|
|
54
|
-
def __getitem__(self, index):
|
|
55
|
-
return {k: v[index] for k, v in self.x.items()}
|
|
56
|
-
|
|
57
|
-
def __len__(self):
|
|
58
|
-
return len(self.x[list(self.x.keys())[0]])
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def get_auto_embedding_dim(num_classes):
|
|
62
|
-
""" Calculate the dim of embedding vector according to number of classes in the category
|
|
63
|
-
emb_dim = [6 * (num_classes)^(1/4)]
|
|
64
|
-
reference: Deep & Cross Network for Ad Click Predictions.(ADKDD'17)
|
|
65
|
-
|
|
66
|
-
Args:
|
|
67
|
-
num_classes: number of classes in the category
|
|
68
|
-
|
|
69
|
-
Returns:
|
|
70
|
-
the dim of embedding vector
|
|
71
|
-
"""
|
|
72
|
-
return np.floor(6 * np.pow(num_classes, 0.26))
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def get_loss_func(task_type="classification"):
|
|
76
|
-
if task_type == "classification":
|
|
77
|
-
return torch.nn.BCELoss()
|
|
78
|
-
elif task_type == "regression":
|
|
79
|
-
return torch.nn.MSELoss()
|
|
80
|
-
else:
|
|
81
|
-
raise ValueError("task_type must be classification or regression")
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def get_metric_func(task_type="classification"):
|
|
85
|
-
if task_type == "classification":
|
|
86
|
-
return roc_auc_score
|
|
87
|
-
elif task_type == "regression":
|
|
88
|
-
return mean_squared_error
|
|
89
|
-
else:
|
|
90
|
-
raise ValueError("task_type must be classification or regression")
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def create_seq_features(data, max_len=50, drop_short=3, shuffle=True):
|
|
94
|
-
"""Build a sequence of user's history by time.
|
|
95
|
-
|
|
96
|
-
Args:
|
|
97
|
-
data (pd.DataFrame): must contain keys: `user_id, item_id, cate_id, time`.
|
|
98
|
-
max_len (int): the max length of a user history sequence.
|
|
99
|
-
drop_short (int): remove some inactive user who's sequence length < drop_short.
|
|
100
|
-
shuffle (bool): shuffle data if true.
|
|
101
|
-
|
|
102
|
-
Returns:
|
|
103
|
-
train (pd.DataFrame): target item will be each item before last two items.
|
|
104
|
-
val (pd.DataFrame): target item is the second to last item of user's history sequence.
|
|
105
|
-
test (pd.DataFrame): target item is the last item of user's history sequence.
|
|
106
|
-
"""
|
|
107
|
-
n_users, n_items, n_cates = data["user_id"].max(), data["item_id"].max(), data["cate_id"].max()
|
|
108
|
-
# 0 to be used as the symbol for padding
|
|
109
|
-
data = data.astype('int32')
|
|
110
|
-
data['item_id'] = data['item_id'].apply(lambda x: x + 1)
|
|
111
|
-
data['cate_id'] = data['cate_id'].apply(lambda x: x + 1)
|
|
112
|
-
|
|
113
|
-
item_cate_map = data[['item_id', 'cate_id']]
|
|
114
|
-
item2cate_dict = item_cate_map.set_index(['item_id'])['cate_id'].to_dict()
|
|
115
|
-
|
|
116
|
-
data = data.sort_values(['user_id', 'time']).groupby('user_id').agg(click_hist_list=('item_id', list), cate_hist_hist=('cate_id', list)).reset_index()
|
|
117
|
-
|
|
118
|
-
# Sliding window to construct negative samples
|
|
119
|
-
train_data, val_data, test_data = [], [], []
|
|
120
|
-
for item in data.itertuples():
|
|
121
|
-
if len(item[2]) < drop_short:
|
|
122
|
-
continue
|
|
123
|
-
click_hist_list = item[2][:max_len]
|
|
124
|
-
cate_hist_list = item[3][:max_len]
|
|
125
|
-
|
|
126
|
-
def neg_sample():
|
|
127
|
-
neg = click_hist_list[0]
|
|
128
|
-
while neg in click_hist_list:
|
|
129
|
-
neg = random.randint(1, n_items)
|
|
130
|
-
return neg
|
|
131
|
-
|
|
132
|
-
neg_list = [neg_sample() for _ in range(len(click_hist_list))]
|
|
133
|
-
hist_list = []
|
|
134
|
-
cate_list = []
|
|
135
|
-
for i in range(1, len(click_hist_list)):
|
|
136
|
-
hist_list.append(click_hist_list[i - 1])
|
|
137
|
-
cate_list.append(cate_hist_list[i - 1])
|
|
138
|
-
hist_list_pad = hist_list + [0] * (max_len - len(hist_list))
|
|
139
|
-
cate_list_pad = cate_list + [0] * (max_len - len(cate_list))
|
|
140
|
-
if i == len(click_hist_list) - 1:
|
|
141
|
-
test_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
|
|
142
|
-
test_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
|
|
143
|
-
if i == len(click_hist_list) - 2:
|
|
144
|
-
val_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
|
|
145
|
-
val_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
|
|
146
|
-
else:
|
|
147
|
-
train_data.append([hist_list_pad, cate_list_pad, click_hist_list[i], cate_hist_list[i], 1])
|
|
148
|
-
train_data.append([hist_list_pad, cate_list_pad, neg_list[i], item2cate_dict[neg_list[i]], 0])
|
|
149
|
-
|
|
150
|
-
# shuffle
|
|
151
|
-
if shuffle:
|
|
152
|
-
random.shuffle(train_data)
|
|
153
|
-
random.shuffle(val_data)
|
|
154
|
-
random.shuffle(test_data)
|
|
155
|
-
|
|
156
|
-
col_name = ['history_item', 'history_cate', 'target_item', 'target_cate', 'label']
|
|
157
|
-
train = pd.DataFrame(train_data, columns=col_name)
|
|
158
|
-
val = pd.DataFrame(val_data, columns=col_name)
|
|
159
|
-
test = pd.DataFrame(test_data, columns=col_name)
|
|
160
|
-
|
|
161
|
-
return train, val, test
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
def df_to_input_dict(data):
|
|
165
|
-
data_dict = data.to_dict('list')
|
|
166
|
-
for key in data.keys():
|
|
167
|
-
data_dict[key] = np.array(data_dict[key])
|
|
168
|
-
return data_dict
|