torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +3 -1
  3. torch_rechub/basic/callback.py +2 -2
  4. torch_rechub/basic/features.py +38 -8
  5. torch_rechub/basic/initializers.py +92 -0
  6. torch_rechub/basic/layers.py +800 -46
  7. torch_rechub/basic/loss_func.py +223 -0
  8. torch_rechub/basic/metaoptimizer.py +76 -0
  9. torch_rechub/basic/metric.py +251 -0
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -0
  14. torch_rechub/models/matching/comirec.py +193 -0
  15. torch_rechub/models/matching/dssm.py +72 -0
  16. torch_rechub/models/matching/dssm_facebook.py +77 -0
  17. torch_rechub/models/matching/dssm_senet.py +87 -0
  18. torch_rechub/models/matching/gru4rec.py +85 -0
  19. torch_rechub/models/matching/mind.py +103 -0
  20. torch_rechub/models/matching/narm.py +82 -0
  21. torch_rechub/models/matching/sasrec.py +143 -0
  22. torch_rechub/models/matching/sine.py +148 -0
  23. torch_rechub/models/matching/stamp.py +81 -0
  24. torch_rechub/models/matching/youtube_dnn.py +75 -0
  25. torch_rechub/models/matching/youtube_sbc.py +98 -0
  26. torch_rechub/models/multi_task/__init__.py +5 -2
  27. torch_rechub/models/multi_task/aitm.py +83 -0
  28. torch_rechub/models/multi_task/esmm.py +19 -8
  29. torch_rechub/models/multi_task/mmoe.py +18 -12
  30. torch_rechub/models/multi_task/ple.py +41 -29
  31. torch_rechub/models/multi_task/shared_bottom.py +3 -2
  32. torch_rechub/models/ranking/__init__.py +13 -2
  33. torch_rechub/models/ranking/afm.py +65 -0
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -0
  36. torch_rechub/models/ranking/dcn.py +38 -0
  37. torch_rechub/models/ranking/dcn_v2.py +59 -0
  38. torch_rechub/models/ranking/deepffm.py +131 -0
  39. torch_rechub/models/ranking/deepfm.py +8 -7
  40. torch_rechub/models/ranking/dien.py +191 -0
  41. torch_rechub/models/ranking/din.py +31 -19
  42. torch_rechub/models/ranking/edcn.py +101 -0
  43. torch_rechub/models/ranking/fibinet.py +42 -0
  44. torch_rechub/models/ranking/widedeep.py +6 -6
  45. torch_rechub/trainers/__init__.py +4 -2
  46. torch_rechub/trainers/ctr_trainer.py +191 -0
  47. torch_rechub/trainers/match_trainer.py +239 -0
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +137 -23
  50. torch_rechub/trainers/seq_trainer.py +293 -0
  51. torch_rechub/utils/__init__.py +0 -0
  52. torch_rechub/utils/data.py +492 -0
  53. torch_rechub/utils/hstu_utils.py +198 -0
  54. torch_rechub/utils/match.py +457 -0
  55. torch_rechub/utils/mtl.py +136 -0
  56. torch_rechub/utils/onnx_export.py +353 -0
  57. torch_rechub-0.0.4.dist-info/METADATA +391 -0
  58. torch_rechub-0.0.4.dist-info/RECORD +62 -0
  59. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
  60. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
  61. torch_rechub/basic/utils.py +0 -168
  62. torch_rechub/trainers/trainer.py +0 -111
  63. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  64. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  65. torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
@@ -1,111 +0,0 @@
1
- import os
2
- import torch
3
- import tqdm
4
- from sklearn.metrics import roc_auc_score
5
- from ..basic.callback import EarlyStopper
6
-
7
-
8
class CTRTrainer(object):
    """A general trainer for single-task models (e.g. CTR prediction).

    Args:
        model (nn.Module): any single-task model whose forward takes a dict of
            feature tensors and returns predictions usable with ``BCELoss``.
        optimizer_fn (torch.optim): optimizer class of pytorch (default = `torch.optim.Adam`).
        optimizer_params (dict): parameters of optimizer_fn
            (default = ``{"lr": 1e-3, "weight_decay": 1e-5}``).
        scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
        scheduler_params (dict): parameters of scheduler_fn (default = no extra parameters).
        n_epoch (int): epoch number of training.
        earlystop_patience (int): patience passed to the early stopper (default=10).
        device (str): `"cpu"` or `"cuda:0"`.
        gpus (list): ids of multiple gpus (default = no gpus). If more than one id is
            given, the model is wrapped by nn.DataParallel.
        model_path (str): directory in which ``model.pth`` is saved (default="./").
            Only the best weights seen on the validation data are saved.
    """

    def __init__(
        self,
        model,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=None,
        scheduler_fn=None,
        scheduler_params=None,
        n_epoch=10,
        earlystop_patience=10,
        device="cpu",
        gpus=None,
        model_path="./",
    ):
        # Build the defaults per call: a dict/list literal in the signature
        # would be shared (and mutable) across every trainer instance.
        if optimizer_params is None:
            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
        if gpus is None:
            gpus = []

        self.model = model  # for uniform weights save method in one gpu or multi gpu
        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
        self.scheduler = None
        if scheduler_fn is not None:
            # scheduler_params may legitimately be omitted; treat None as "no kwargs".
            self.scheduler = scheduler_fn(self.optimizer, **(scheduler_params or {}))
        self.criterion = torch.nn.BCELoss()  # default loss: binary cross entropy
        self.evaluate_fn = roc_auc_score  # default evaluate function
        self.n_epoch = n_epoch
        self.early_stopper = EarlyStopper(patience=earlystop_patience)
        self.device = torch.device(device)
        self.gpus = gpus
        if len(gpus) > 1:
            print('parallel running on these gpus:', gpus)
            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
        self.model_path = model_path

    @staticmethod
    def _features_of(batch):
        """Return the feature dict of a batch given as ``x_dict`` or ``(x_dict, ...)``."""
        if isinstance(batch, (tuple, list)):
            return batch[0]
        return batch

    def _to_device(self, x_dict):
        """Return a copy of ``x_dict`` with every value moved to ``self.device``."""
        return {name: value.to(self.device) for name, value in x_dict.items()}

    def train_one_epoch(self, data_loader, log_interval=10):
        """Run one optimisation pass over ``data_loader``.

        Args:
            data_loader: iterable yielding ``(x_dict, y)`` batches.
            log_interval (int): number of batches between progress-bar loss updates;
                the displayed value is the mean loss over that window.
        """
        self.model.train()
        total_loss = 0
        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
        for i, (x_dict, y) in enumerate(tk0):
            x_dict = self._to_device(x_dict)
            y = y.to(self.device)
            y_pred = self.model(x_dict)
            loss = self.criterion(y_pred, y.float())
            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=total_loss / log_interval)
                total_loss = 0

    def fit(self, train_dataloader, val_dataloader):
        """Train for up to ``n_epoch`` epochs with early stopping on validation AUC.

        After training (whether stopped early or not) the best weights recorded
        by the early stopper, if any, are loaded back into the model and saved
        to ``<model_path>/model.pth``.

        Args:
            train_dataloader: iterable yielding ``(x_dict, y)`` training batches.
            val_dataloader: iterable yielding ``(x_dict, y)`` validation batches.
        """
        self.model.to(self.device)
        for epoch_i in range(self.n_epoch):
            self.train_one_epoch(train_dataloader)
            if self.scheduler is not None:
                # Only StepLR-like schedulers expose ``step_size``; fall back to
                # reporting the lr every epoch for the others.
                step_size = getattr(self.scheduler, "step_size", 1)
                if epoch_i % step_size == 0:
                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                self.scheduler.step()  # update lr in epoch level by scheduler
            auc = self.evaluate(self.model, val_dataloader)
            print('epoch:', epoch_i, 'validation: auc:', auc)
            if self.early_stopper.stop_training(auc, self.model.state_dict()):
                print(f'validation: best auc: {self.early_stopper.best_auc}')
                break

        # Restore and persist the best weights even when training ran through
        # all epochs without triggering early stopping.
        best_weights = self.early_stopper.best_weights
        if best_weights is not None:
            self.model.load_state_dict(best_weights)
            os.makedirs(self.model_path, exist_ok=True)
            torch.save(best_weights, os.path.join(self.model_path, "model.pth"))  # save best auc model

    def evaluate(self, model, data_loader):
        """Score ``model`` on ``data_loader`` with ``self.evaluate_fn``.

        Args:
            model (nn.Module): model to evaluate; it is switched to eval mode.
            data_loader: iterable yielding ``(x_dict, y)`` batches.

        Returns:
            The value of ``self.evaluate_fn(targets, predicts)`` (ROC AUC by default).
        """
        model.eval()
        targets, predicts = [], []
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
            for x_dict, y in tk0:
                y_pred = model(self._to_device(x_dict))
                targets.extend(y.tolist())
                predicts.extend(y_pred.tolist())
        return self.evaluate_fn(targets, predicts)

    def predict(self, model, data_loader):
        """Return the model predictions for every sample in ``data_loader``.

        Args:
            model (nn.Module): model to run; it is switched to eval mode.
            data_loader: iterable yielding either ``x_dict`` or ``(x_dict, ...)``
                batches; anything after the feature dict is ignored.

        Returns:
            list: predictions of all batches, concatenated in iteration order.
        """
        model.eval()
        predicts = []
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
            for batch in tk0:
                x_dict = self._to_device(self._features_of(batch))
                y_pred = model(x_dict)
                predicts.extend(y_pred.tolist())
        return predicts
@@ -1,105 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: torch-rechub
3
- Version: 0.0.1
4
- Summary: A Lighting Pytorch Framework for Recommendation System, Easy-to-use and Easy-to-extend.
5
- Home-page: https://github.com/morningsky/Torch-RecHub
6
- Author: Mincai lai
7
- Author-email: 757387961@qq.com
8
- License: UNKNOWN
9
- Keywords: ctr,click through rate,deep learning,pytorch,recsys,recommendation
10
- Platform: all
11
- Classifier: Intended Audience :: Developers
12
- Classifier: Intended Audience :: Education
13
- Classifier: Intended Audience :: Science/Research
14
- Classifier: Operating System :: OS Independent
15
- Classifier: Programming Language :: Python :: 3
16
- Classifier: Programming Language :: Python :: 3.7
17
- Classifier: Programming Language :: Python :: 3.8
18
- Classifier: Programming Language :: Python :: 3.9
19
- Classifier: Topic :: Scientific/Engineering
20
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
21
- Classifier: Topic :: Software Development :: Libraries
22
- Classifier: Topic :: Software Development :: Libraries :: Python Modules
23
- Classifier: License :: OSI Approved :: MIT License
24
- Description-Content-Type: text/markdown
25
- License-File: LICENSE
26
- Requires-Dist: numpy (>=1.21.5)
27
- Requires-Dist: torch (>=1.7.0)
28
- Requires-Dist: pandas (>=1.0.5)
29
- Requires-Dist: tqdm (>=4.64.0)
30
- Requires-Dist: scikit-learn (>=0.23.2)
31
-
32
- # Torch-RecHub
33
-
34
- A Lighting Pytorch Framework for Recommendation Models, Easy-to-use and Easy-to-extend.
35
-
36
- ## 安装
37
-
38
- ```bash
39
- pip install torch-rechub
40
- ```
41
-
42
- ## 主要特性
43
-
44
- - scikit-learn风格易用的API(fit、predict),即插即用
45
- - 训练过程与模型定义解耦,易拓展,可针对不同类型的模型设置不同的训练机制
46
- - 使用Pytorch原生Dataset、DataLoader,易修改,自定义数据
47
- - 高度模块化,支持常见Layer(MLP、FM、FFM、target-attention、self-attention、transformer等),容易调用组装成新模型
48
- - 支持常见排序模型(WideDeep、DeepFM、DIN、DCN、xDeepFM等)
49
-
50
- - [ ] 支持常见召回模型(DSSM、YoutubeDNN、MIND、SASRec等)
51
- - 丰富的多任务学习支持
52
- - SharedBottom、ESMM、MMOE、PLE、AITM等模型
53
- - GradNorm、UWL等动态loss加权机制
54
-
55
- - 聚焦更生态化的推荐场景
56
- - [ ] 冷启动
57
- - [ ] 延迟反馈
58
- - [ ] 去偏
59
- - [ ] 支持丰富的训练机制(对比学习、蒸馏学习等)
60
-
61
- - [ ] 第三方高性能开源Trainer支持(PyTorch Lightning等)
62
- - [ ] 更多模型正在开发中
63
-
64
- ## 快速使用
65
-
66
- ```python
67
- from torch_rechub.models import WideDeep, DeepFM, DIN
68
- from torch_rechub.trainers import CTRTrainer
69
- from torch_rechub.basic.utils import DataGenerator
70
-
71
- dg = DataGenerator(x, y)
72
- train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader()
73
-
74
- model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
75
-
76
- ctr_trainer = CTRTrainer(model)
77
- ctr_trainer.fit(train_dataloader, val_dataloader)
78
- auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
79
-
80
-
81
- ```
82
-
83
-
84
-
85
-
86
-
87
- > **Note:**
88
- >
89
- > 所有模型均在大多数论文提及的多个知名公开数据集中测试,达到或者接近论文性能。
90
- >
91
- > 使用案例:[Examples](./examples)
92
- >
93
- > 每个数据集将会提供
94
- >
95
- > - 一个使用脚本,包含样本生成、模型训练与测试,并提供一套测评所用参数。
96
- > - 一个预处理脚本,将原始数据进行预处理,转化成csv。
97
- > - 数据格式参考文件(100条)。
98
- > - 全量数据,统一的csv文件,提供高速网盘下载链接和原始数据链接。
99
-
100
-
101
-
102
- [初步规划TODO清单](https://user-images.githubusercontent.com/11856746/167436396-f9c5de5b-d341-4697-8b91-884d4ae552be.png)
103
-
104
-
105
-
@@ -1,26 +0,0 @@
1
- torch_rechub/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- torch_rechub/basic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- torch_rechub/basic/activation.py,sha256=aIvomPA2X00dl3svv4lU2a4TvvltWxuqy22oML0e5YI,1519
4
- torch_rechub/basic/callback.py,sha256=Z62CIrPF2axyBSTRUnbc--9ADH8oF9YLTvLtxzvsMYE,953
5
- torch_rechub/basic/features.py,sha256=T4mA8I4JgCFZR0BZLrLRxKUdXa0Aaum9tJVlN3pdk_Y,1924
6
- torch_rechub/basic/layers.py,sha256=VCiaK0G4asvvzH4f7TobA7tyVhzgmc3FzvFK783z5yI,9354
7
- torch_rechub/basic/utils.py,sha256=v1NGrWVdj_CZBA7bERPEWqX5TcnQm0rVMDebDGXH17s,6581
8
- torch_rechub/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- torch_rechub/models/matching/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- torch_rechub/models/multi_task/__init__.py,sha256=tfy0NEOZlSkpSv6GSM3IFxPFGTrV3lpEXPAIGfUfvOM,106
11
- torch_rechub/models/multi_task/esmm.py,sha256=h7xSOxegyxfic5VFCEXEXtJLXanuRNBZ_T_rMMKZyyM,2168
12
- torch_rechub/models/multi_task/mmoe.py,sha256=2O4qUcmCMUKUW2gES6R84B2NTuKulgo6V01umhXPRQk,2818
13
- torch_rechub/models/multi_task/ple.py,sha256=m7k2D146Qg2ppKDtbXsPDFwXtSTjpsYO2hAFbFue7zQ,6467
14
- torch_rechub/models/multi_task/shared_bottom.py,sha256=3sBN5oaM912vJ1jzCpzaScnILIxobIA566cvQGS-TWQ,1855
15
- torch_rechub/models/ranking/__init__.py,sha256=k8fO0qs0WdbMcbOft2aV7NhXkmOAmBVTrQ6I1IMX7q4,78
16
- torch_rechub/models/ranking/deepfm.py,sha256=s3dAf5CYfNVHI62wXR2CZujmZ9V3yBdOBG6Y1ZgKOQc,1774
17
- torch_rechub/models/ranking/din.py,sha256=7UxS4ufcAi8ladJY68d3WEbn4986zHre8LaB4s9OVSs,4500
18
- torch_rechub/models/ranking/widedeep.py,sha256=ZpjhmIrrLygThUNP4ZWplgSboizrQjioEdy3SmGt3pM,1634
19
- torch_rechub/trainers/__init__.py,sha256=-HYw87_OnPJtwyDSUFioBNDJ7aiMM-bZNmiqV9vbfpw,67
20
- torch_rechub/trainers/mtl_trainer.py,sha256=yD98xeQuaI6cY6sN3ytDqrl4OQX113UA2kstDLC1rT4,6942
21
- torch_rechub/trainers/trainer.py,sha256=J7GhFed1QzBsUKntJD8qA2gMEtxAnx7haznpiousIk0,5052
22
- torch_rechub-0.0.1.dist-info/LICENSE,sha256=uz-kPWsf6PAphHAysV_p4A_MUcogaJbNrjPlINUC1x4,1067
23
- torch_rechub-0.0.1.dist-info/METADATA,sha256=XOlqfEwVKBuTpR2uXoavaT1hWC8AksgTPNlYy8-rcrk,3622
24
- torch_rechub-0.0.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
25
- torch_rechub-0.0.1.dist-info/top_level.txt,sha256=_RPvAnlLHdll9u2d74aFr_oOAj_NxgfLpH02Uifz_YY,13
26
- torch_rechub-0.0.1.dist-info/RECORD,,
@@ -1 +0,0 @@
1
- torch_rechub