torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +3 -1
- torch_rechub/basic/callback.py +2 -2
- torch_rechub/basic/features.py +38 -8
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +800 -46
- torch_rechub/basic/loss_func.py +223 -0
- torch_rechub/basic/metaoptimizer.py +76 -0
- torch_rechub/basic/metric.py +251 -0
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -0
- torch_rechub/models/matching/comirec.py +193 -0
- torch_rechub/models/matching/dssm.py +72 -0
- torch_rechub/models/matching/dssm_facebook.py +77 -0
- torch_rechub/models/matching/dssm_senet.py +87 -0
- torch_rechub/models/matching/gru4rec.py +85 -0
- torch_rechub/models/matching/mind.py +103 -0
- torch_rechub/models/matching/narm.py +82 -0
- torch_rechub/models/matching/sasrec.py +143 -0
- torch_rechub/models/matching/sine.py +148 -0
- torch_rechub/models/matching/stamp.py +81 -0
- torch_rechub/models/matching/youtube_dnn.py +75 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -2
- torch_rechub/models/multi_task/aitm.py +83 -0
- torch_rechub/models/multi_task/esmm.py +19 -8
- torch_rechub/models/multi_task/mmoe.py +18 -12
- torch_rechub/models/multi_task/ple.py +41 -29
- torch_rechub/models/multi_task/shared_bottom.py +3 -2
- torch_rechub/models/ranking/__init__.py +13 -2
- torch_rechub/models/ranking/afm.py +65 -0
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +59 -0
- torch_rechub/models/ranking/deepffm.py +131 -0
- torch_rechub/models/ranking/deepfm.py +8 -7
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +31 -19
- torch_rechub/models/ranking/edcn.py +101 -0
- torch_rechub/models/ranking/fibinet.py +42 -0
- torch_rechub/models/ranking/widedeep.py +6 -6
- torch_rechub/trainers/__init__.py +4 -2
- torch_rechub/trainers/ctr_trainer.py +191 -0
- torch_rechub/trainers/match_trainer.py +239 -0
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +137 -23
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +492 -0
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -0
- torch_rechub/utils/mtl.py +136 -0
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub/trainers/trainer.py +0 -111
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
torch_rechub/trainers/trainer.py
DELETED
|
@@ -1,111 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import torch
|
|
3
|
-
import tqdm
|
|
4
|
-
from sklearn.metrics import roc_auc_score
|
|
5
|
-
from ..basic.callback import EarlyStopper
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
class CTRTrainer(object):
    """A general trainer for single-task (CTR) learning.

    Args:
        model (nn.Module): any single task learning model.
        optimizer_fn (torch.optim): optimizer class of pytorch (default = `torch.optim.Adam`).
        optimizer_params (dict): parameters of optimizer_fn
            (default = `{"lr": 1e-3, "weight_decay": 1e-5}`).
        scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
        scheduler_params (dict): parameters of scheduler_fn.
        n_epoch (int): epoch number of training.
        earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
        device (str): `"cpu"` or `"cuda:0"`.
        gpus (list): ids of multi gpu (default=None, treated as []). If the length > 1,
            the model will be wrapped by nn.DataParallel.
        model_path (str): the directory where the best-validation weights are saved
            as "model.pth" (default="./").
    """

    def __init__(
        self,
        model,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=None,
        scheduler_fn=None,
        scheduler_params=None,
        n_epoch=10,
        earlystop_patience=10,
        device="cpu",
        gpus=None,
        model_path="./",
    ):
        self.model = model  # kept for a uniform weight-saving path on one gpu or multi gpu
        # Avoid mutable default arguments: build the documented defaults per instance.
        if optimizer_params is None:
            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
        self.scheduler = None
        if scheduler_fn is not None:
            # `or {}` guards against scheduler_fn being given without params.
            self.scheduler = scheduler_fn(self.optimizer, **(scheduler_params or {}))
        self.criterion = torch.nn.BCELoss()  # default loss: binary cross entropy
        self.evaluate_fn = roc_auc_score  # default evaluate function
        self.n_epoch = n_epoch
        self.early_stopper = EarlyStopper(patience=earlystop_patience)
        self.device = torch.device(device)
        self.gpus = gpus if gpus is not None else []
        if len(self.gpus) > 1:
            print('parallel running on these gpus:', self.gpus)
            self.model = torch.nn.DataParallel(self.model, device_ids=self.gpus)
        self.model_path = model_path

    def train_one_epoch(self, data_loader, log_interval=10):
        """Run one optimization pass over `data_loader`, logging the mean loss
        every `log_interval` batches."""
        self.model.train()
        total_loss = 0
        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
        for i, (x_dict, y) in enumerate(tk0):
            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
            y = y.to(self.device)
            y_pred = self.model(x_dict)
            loss = self.criterion(y_pred, y.float())
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=total_loss / log_interval)
                total_loss = 0

    def fit(self, train_dataloader, val_dataloader):
        """Train for up to `n_epoch` epochs with epoch-level validation and
        early stopping; the best-validation weights are restored and saved."""
        self.model.to(self.device)
        for epoch_i in range(self.n_epoch):
            self.train_one_epoch(train_dataloader)
            if self.scheduler is not None:
                # Step the scheduler once per epoch and let it decide when to
                # decay. The old `epoch_i % self.scheduler.step_size` guard only
                # worked for StepLR (other schedulers have no `step_size`) and
                # double-applied StepLR's own decay schedule.
                self.scheduler.step()
                print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
            auc = self.evaluate(self.model, val_dataloader)
            print('epoch:', epoch_i, 'validation: auc:', auc)
            if self.early_stopper.stop_training(auc, self.model.state_dict()):
                print(f'validation: best auc: {self.early_stopper.best_auc}')
                break
        # Restore and persist the best weights even when early stopping never
        # fired (the old code only saved inside the early-stop branch).
        # NOTE(review): assumes EarlyStopper records best_weights on every
        # stop_training call — confirm against basic/callback.py.
        self.model.load_state_dict(self.early_stopper.best_weights)
        torch.save(self.early_stopper.best_weights, os.path.join(self.model_path, "model.pth"))  # save best auc model

    def evaluate(self, model, data_loader):
        """Return `evaluate_fn` (AUC by default) of `model` over `data_loader`."""
        model.eval()
        targets, predicts = list(), list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
            for i, (x_dict, y) in enumerate(tk0):
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                y = y.to(self.device)
                y_pred = model(x_dict)
                targets.extend(y.tolist())
                predicts.extend(y_pred.tolist())
        return self.evaluate_fn(targets, predicts)

    def predict(self, model, data_loader):
        """Return the list of model predictions over an unlabeled `data_loader`."""
        model.eval()
        predicts = list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
            for i, x_dict in enumerate(tk0):
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                y_pred = model(x_dict)
                predicts.extend(y_pred.tolist())
        return predicts
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
2
|
-
Name: torch-rechub
|
|
3
|
-
Version: 0.0.1
|
|
4
|
-
Summary: A Lighting Pytorch Framework for Recommendation System, Easy-to-use and Easy-to-extend.
|
|
5
|
-
Home-page: https://github.com/morningsky/Torch-RecHub
|
|
6
|
-
Author: Mincai lai
|
|
7
|
-
Author-email: 757387961@qq.com
|
|
8
|
-
License: UNKNOWN
|
|
9
|
-
Keywords: ctr,click through rate,deep learning,pytorch,recsys,recommendation
|
|
10
|
-
Platform: all
|
|
11
|
-
Classifier: Intended Audience :: Developers
|
|
12
|
-
Classifier: Intended Audience :: Education
|
|
13
|
-
Classifier: Intended Audience :: Science/Research
|
|
14
|
-
Classifier: Operating System :: OS Independent
|
|
15
|
-
Classifier: Programming Language :: Python :: 3
|
|
16
|
-
Classifier: Programming Language :: Python :: 3.7
|
|
17
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
18
|
-
Classifier: Programming Language :: Python :: 3.9
|
|
19
|
-
Classifier: Topic :: Scientific/Engineering
|
|
20
|
-
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
-
Classifier: Topic :: Software Development :: Libraries
|
|
22
|
-
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
23
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
24
|
-
Description-Content-Type: text/markdown
|
|
25
|
-
License-File: LICENSE
|
|
26
|
-
Requires-Dist: numpy (>=1.21.5)
|
|
27
|
-
Requires-Dist: torch (>=1.7.0)
|
|
28
|
-
Requires-Dist: pandas (>=1.0.5)
|
|
29
|
-
Requires-Dist: tqdm (>=4.64.0)
|
|
30
|
-
Requires-Dist: scikit-learn (>=0.23.2)
|
|
31
|
-
|
|
32
|
-
# Torch-RecHub
|
|
33
|
-
|
|
34
|
-
A Lightning Pytorch Framework for Recommendation Models, Easy-to-use and Easy-to-extend.
|
|
35
|
-
|
|
36
|
-
## 安装
|
|
37
|
-
|
|
38
|
-
```python
|
|
39
|
-
pip install torch-rechub
|
|
40
|
-
```
|
|
41
|
-
|
|
42
|
-
## 主要特性
|
|
43
|
-
|
|
44
|
-
- scikit-learn风格易用的API(fit、predict),即插即用
|
|
45
|
-
- 训练过程与模型定义解耦,易拓展,可针对不同类型的模型设置不同的训练机制
|
|
46
|
-
- 使用Pytorch原生Dataset、DataLoader,易修改,自定义数据
|
|
47
|
-
- 高度模块化,支持常见Layer(MLP、FM、FFM、target-attention、self-attention、transformer等),容易调用组装成新模型
|
|
48
|
-
- 支持常见排序模型(WideDeep、DeepFM、DIN、DCN、xDeepFM等)
|
|
49
|
-
|
|
50
|
-
- [ ] 支持常见召回模型(DSSM、YoutubeDNN、MIND、SASRec等)
|
|
51
|
-
- 丰富的多任务学习支持
|
|
52
|
-
- SharedBottom、ESMM、MMOE、PLE、AITM等模型
|
|
53
|
-
- GradNorm、UWL等动态loss加权机制
|
|
54
|
-
|
|
55
|
-
- 聚焦更生态化的推荐场景
|
|
56
|
-
- [ ] 冷启动
|
|
57
|
-
- [ ] 延迟反馈
|
|
58
|
-
- [ ] 去偏
|
|
59
|
-
- [ ] 支持丰富的训练机制(对比学习、蒸馏学习等)
|
|
60
|
-
|
|
61
|
-
- [ ] 第三方高性能开源Trainer支持(Pytorch Lighting等)
|
|
62
|
-
- [ ] 更多模型正在开发中
|
|
63
|
-
|
|
64
|
-
## 快速使用
|
|
65
|
-
|
|
66
|
-
```python
|
|
67
|
-
from torch_rechub.models import WideDeep, DeepFM, DIN
|
|
68
|
-
from torch_rechub.trainers import CTRTrainer
|
|
69
|
-
from torch_rechub.basic.utils import DataGenerator
|
|
70
|
-
|
|
71
|
-
dg = DataGenerator(x, y)
|
|
72
|
-
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader()
|
|
73
|
-
|
|
74
|
-
model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
|
|
75
|
-
|
|
76
|
-
ctr_trainer = CTRTrainer(model)
|
|
77
|
-
ctr_trainer.fit(train_dataloader, val_dataloader)
|
|
78
|
-
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
```
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
> **Note:**
|
|
88
|
-
>
|
|
89
|
-
> 所有模型均在大多数论文提及的多个知名公开数据集中测试,达到或者接近论文性能。
|
|
90
|
-
>
|
|
91
|
-
> 使用案例:[Examples](./examples)
|
|
92
|
-
>
|
|
93
|
-
> 每个数据集将会提供
|
|
94
|
-
>
|
|
95
|
-
> - 一个使用脚本,包含样本生成、模型训练与测试,并提供一套测评所用参数。
|
|
96
|
-
> - 一个预处理脚本,将原始数据进行预处理,转化成csv。
|
|
97
|
-
> - 数据格式参考文件(100条)。
|
|
98
|
-
> - 全量数据,统一的csv文件,提供高速网盘下载链接和原始数据链接。
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
[初步规划TODO清单](https://user-images.githubusercontent.com/11856746/167436396-f9c5de5b-d341-4697-8b91-884d4ae552be.png)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
@@ -1,26 +0,0 @@
|
|
|
1
|
-
torch_rechub/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
torch_rechub/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
-
torch_rechub/basic/activation.py,sha256=aIvomPA2X00dl3svv4lU2a4TvvltWxuqy22oML0e5YI,1519
|
|
4
|
-
torch_rechub/basic/callback.py,sha256=Z62CIrPF2axyBSTRUnbc--9ADH8oF9YLTvLtxzvsMYE,953
|
|
5
|
-
torch_rechub/basic/features.py,sha256=T4mA8I4JgCFZR0BZLrLRxKUdXa0Aaum9tJVlN3pdk_Y,1924
|
|
6
|
-
torch_rechub/basic/layers.py,sha256=VCiaK0G4asvvzH4f7TobA7tyVhzgmc3FzvFK783z5yI,9354
|
|
7
|
-
torch_rechub/basic/utils.py,sha256=v1NGrWVdj_CZBA7bERPEWqX5TcnQm0rVMDebDGXH17s,6581
|
|
8
|
-
torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
|
-
torch_rechub/models/matching/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
-
torch_rechub/models/multi_task/__init__.py,sha256=tfy0NEOZlSkpSv6GSM3IFxPFGTrV3lpEXPAIGfUfvOM,106
|
|
11
|
-
torch_rechub/models/multi_task/esmm.py,sha256=h7xSOxegyxfic5VFCEXEXtJLXanuRNBZ_T_rMMKZyyM,2168
|
|
12
|
-
torch_rechub/models/multi_task/mmoe.py,sha256=2O4qUcmCMUKUW2gES6R84B2NTuKulgo6V01umhXPRQk,2818
|
|
13
|
-
torch_rechub/models/multi_task/ple.py,sha256=m7k2D146Qg2ppKDtbXsPDFwXtSTjpsYO2hAFbFue7zQ,6467
|
|
14
|
-
torch_rechub/models/multi_task/shared_bottom.py,sha256=3sBN5oaM912vJ1jzCpzaScnILIxobIA566cvQGS-TWQ,1855
|
|
15
|
-
torch_rechub/models/ranking/__init__.py,sha256=k8fO0qs0WdbMcbOft2aV7NhXkmOAmBVTrQ6I1IMX7q4,78
|
|
16
|
-
torch_rechub/models/ranking/deepfm.py,sha256=s3dAf5CYfNVHI62wXR2CZujmZ9V3yBdOBG6Y1ZgKOQc,1774
|
|
17
|
-
torch_rechub/models/ranking/din.py,sha256=7UxS4ufcAi8ladJY68d3WEbn4986zHre8LaB4s9OVSs,4500
|
|
18
|
-
torch_rechub/models/ranking/widedeep.py,sha256=ZpjhmIrrLygThUNP4ZWplgSboizrQjioEdy3SmGt3pM,1634
|
|
19
|
-
torch_rechub/trainers/__init__.py,sha256=-HYw87_OnPJtwyDSUFioBNDJ7aiMM-bZNmiqV9vbfpw,67
|
|
20
|
-
torch_rechub/trainers/mtl_trainer.py,sha256=yD98xeQuaI6cY6sN3ytDqrl4OQX113UA2kstDLC1rT4,6942
|
|
21
|
-
torch_rechub/trainers/trainer.py,sha256=J7GhFed1QzBsUKntJD8qA2gMEtxAnx7haznpiousIk0,5052
|
|
22
|
-
torch_rechub-0.0.1.dist-info/LICENSE,sha256=uz-kPWsf6PAphHAysV_p4A_MUcogaJbNrjPlINUC1x4,1067
|
|
23
|
-
torch_rechub-0.0.1.dist-info/METADATA,sha256=XOlqfEwVKBuTpR2uXoavaT1hWC8AksgTPNlYy8-rcrk,3622
|
|
24
|
-
torch_rechub-0.0.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
|
|
25
|
-
torch_rechub-0.0.1.dist-info/top_level.txt,sha256=_RPvAnlLHdll9u2d74aFr_oOAj_NxgfLpH02Uifz_YY,13
|
|
26
|
-
torch_rechub-0.0.1.dist-info/RECORD,,
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
torch_rechub
|