torch-rechub 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +54 -54
- torch_rechub/basic/callback.py +33 -33
- torch_rechub/basic/features.py +87 -94
- torch_rechub/basic/initializers.py +92 -92
- torch_rechub/basic/layers.py +994 -720
- torch_rechub/basic/loss_func.py +223 -34
- torch_rechub/basic/metaoptimizer.py +76 -72
- torch_rechub/basic/metric.py +251 -250
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -11
- torch_rechub/models/matching/comirec.py +193 -188
- torch_rechub/models/matching/dssm.py +72 -66
- torch_rechub/models/matching/dssm_facebook.py +77 -79
- torch_rechub/models/matching/dssm_senet.py +28 -16
- torch_rechub/models/matching/gru4rec.py +85 -87
- torch_rechub/models/matching/mind.py +103 -101
- torch_rechub/models/matching/narm.py +82 -76
- torch_rechub/models/matching/sasrec.py +143 -140
- torch_rechub/models/matching/sine.py +148 -151
- torch_rechub/models/matching/stamp.py +81 -83
- torch_rechub/models/matching/youtube_dnn.py +75 -71
- torch_rechub/models/matching/youtube_sbc.py +98 -98
- torch_rechub/models/multi_task/__init__.py +7 -5
- torch_rechub/models/multi_task/aitm.py +83 -84
- torch_rechub/models/multi_task/esmm.py +56 -55
- torch_rechub/models/multi_task/mmoe.py +58 -58
- torch_rechub/models/multi_task/ple.py +116 -130
- torch_rechub/models/multi_task/shared_bottom.py +45 -45
- torch_rechub/models/ranking/__init__.py +14 -11
- torch_rechub/models/ranking/afm.py +65 -63
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -63
- torch_rechub/models/ranking/dcn.py +38 -38
- torch_rechub/models/ranking/dcn_v2.py +59 -69
- torch_rechub/models/ranking/deepffm.py +131 -123
- torch_rechub/models/ranking/deepfm.py +43 -42
- torch_rechub/models/ranking/dien.py +191 -191
- torch_rechub/models/ranking/din.py +93 -91
- torch_rechub/models/ranking/edcn.py +101 -117
- torch_rechub/models/ranking/fibinet.py +42 -50
- torch_rechub/models/ranking/widedeep.py +41 -41
- torch_rechub/trainers/__init__.py +4 -3
- torch_rechub/trainers/ctr_trainer.py +191 -128
- torch_rechub/trainers/match_trainer.py +239 -170
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +259 -207
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/data.py +492 -360
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -274
- torch_rechub/utils/mtl.py +136 -126
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.3.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +21 -21
- torch_rechub-0.0.3.dist-info/METADATA +0 -177
- torch_rechub-0.0.3.dist-info/RECORD +0 -55
- torch_rechub-0.0.3.dist-info/top_level.txt +0 -1
|
@@ -1,128 +1,191 @@
|
|
|
1
|
-
import os
|
|
2
|
-
|
|
3
|
-
import
|
|
4
|
-
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
self.model
|
|
48
|
-
if
|
|
49
|
-
|
|
50
|
-
self.
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
self.
|
|
54
|
-
|
|
55
|
-
self.
|
|
56
|
-
self.
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
self.
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
self.
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import tqdm
|
|
5
|
+
from sklearn.metrics import roc_auc_score
|
|
6
|
+
|
|
7
|
+
from ..basic.callback import EarlyStopper
|
|
8
|
+
from ..basic.loss_func import RegularizationLoss
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CTRTrainer(object):
    """A general trainer for single task learning (e.g. CTR prediction).

    Args:
        model (nn.Module): any single task learning model.
        optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
        optimizer_params (dict): parameters of optimizer_fn
            (default = `{"lr": 1e-3, "weight_decay": 1e-5}`).
        regularization_params (dict): keyword arguments forwarded to `RegularizationLoss`,
            i.e. `embedding_l1`, `embedding_l2`, `dense_l1`, `dense_l2` (all default to 0.0).
        scheduler_fn (torch.optim.lr_scheduler): torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
        scheduler_params (dict): parameters of scheduler_fn (default = `{}`).
        n_epoch (int): epoch number of training.
        earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
        device (str): `"cpu"` or `"cuda:0"`.
        gpus (list): ids of multiple gpus (default=[]). If more than one id is given,
            the model is wrapped by nn.DataParallel.
        loss_mode (bool): if True (default), `model(x_dict)` returns only predictions;
            if False, it returns `(predictions, auxiliary_loss)` and the auxiliary loss
            is added to the BCE loss during training.
        model_path (str): directory where `model.pth` is saved (default="./"). When a
            validation loader is given, the best weights on the validation data are saved.
    """

    def __init__(
        self,
        model,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=None,
        regularization_params=None,
        scheduler_fn=None,
        scheduler_params=None,
        n_epoch=10,
        earlystop_patience=10,
        device="cpu",
        gpus=None,
        loss_mode=True,
        model_path="./",
    ):
        self.model = model
        if gpus is None:
            gpus = []
        self.gpus = gpus
        if len(gpus) > 1:
            print('parallel running on these gpus:', gpus)
            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
        # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)
        self.model.to(self.device)
        if optimizer_params is None:
            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
        if regularization_params is None:
            regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
        self.scheduler = None
        if scheduler_fn is not None:
            # Allow a scheduler with no extra arguments; `**None` would raise TypeError.
            if scheduler_params is None:
                scheduler_params = {}
            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
        self.loss_mode = loss_mode
        self.criterion = torch.nn.BCELoss()  # default loss: binary cross entropy
        self.evaluate_fn = roc_auc_score  # default evaluate function
        self.n_epoch = n_epoch
        self.early_stopper = EarlyStopper(patience=earlystop_patience)
        self.model_path = model_path
        # Regularization penalty added to the task loss at every training step.
        self.reg_loss_fn = RegularizationLoss(**regularization_params)

    def train_one_epoch(self, data_loader, log_interval=10):
        """Run one training pass over `data_loader`.

        Args:
            data_loader: iterable yielding `(x_dict, y)` batches, where `x_dict` maps
                feature names to tensors.
            log_interval (int): number of batches over which the loss shown in the
                progress bar is averaged.
        """
        self.model.train()
        total_loss = 0
        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
        for i, (x_dict, y) in enumerate(tk0):
            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to device
            y = y.to(self.device).float()
            if self.loss_mode:
                y_pred = self.model(x_dict)
                loss = self.criterion(y_pred, y)
            else:
                # Model also returns an auxiliary loss term to be optimized jointly.
                y_pred, other_loss = self.model(x_dict)
                loss = self.criterion(y_pred, y) + other_loss

            # Add regularization loss
            reg_loss = self.reg_loss_fn(self.model)
            loss = loss + reg_loss

            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=total_loss / log_interval)
                total_loss = 0

    def fit(self, train_dataloader, val_dataloader=None):
        """Train for up to `n_epoch` epochs and save the weights to `model_path/model.pth`.

        When `val_dataloader` is given, validation AUC is computed after each epoch,
        training stops early once it has not improved for `earlystop_patience` epochs,
        and the best validation weights recorded by the early stopper are restored
        before saving.
        """
        for epoch_i in range(self.n_epoch):
            print('epoch:', epoch_i)
            self.train_one_epoch(train_dataloader)
            if self.scheduler is not None:
                # Only StepLR-like schedulers expose `step_size`; log every epoch otherwise.
                step_size = getattr(self.scheduler, "step_size", 1)
                if epoch_i % step_size == 0:
                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                self.scheduler.step()  # update lr in epoch level by scheduler
            # Explicit None check: DataLoader truthiness calls __len__, which raises
            # TypeError for iterable datasets without a length.
            if val_dataloader is not None:
                auc = self.evaluate(self.model, val_dataloader)
                print('epoch:', epoch_i, 'validation: auc:', auc)
                if self.early_stopper.stop_training(auc, self.model.state_dict()):
                    print(f'validation: best auc: {self.early_stopper.best_auc}')
                    self.model.load_state_dict(self.early_stopper.best_weights)
                    break
        # If training ran to completion without an early stop, the current weights are
        # from the last epoch; restore the best validated ones so the saved file matches
        # the "best weights" contract.
        if val_dataloader is not None:
            best_weights = getattr(self.early_stopper, "best_weights", None)
            if best_weights is not None:
                self.model.load_state_dict(best_weights)
        torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model

    def evaluate(self, model, data_loader):
        """Compute `self.evaluate_fn` (AUC by default) of `model` on `data_loader`.

        Targets and predictions are flattened to 1-D per batch so that their shapes
        agree regardless of whether the model outputs `(batch,)` or `(batch, 1)`.
        """
        model.eval()
        targets, predicts = list(), list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
            for x_dict, y in tk0:
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                y = y.to(self.device).float().reshape(-1)
                if self.loss_mode:
                    y_pred = model(x_dict)
                else:
                    y_pred, _ = model(x_dict)
                targets.extend(y.tolist())
                predicts.extend(y_pred.reshape(-1).tolist())
        return self.evaluate_fn(targets, predicts)

    def predict(self, model, data_loader):
        """Return the model's predictions on `data_loader` as a Python list.

        Batches are expected as `(x_dict, y)` pairs; `y` is ignored.
        """
        model.eval()
        predicts = list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
            for x_dict, _ in tk0:
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                if self.loss_mode:
                    y_pred = model(x_dict)
                else:
                    y_pred, _ = model(x_dict)
                predicts.extend(y_pred.tolist())
        return predicts

    def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
        """Export the trained model to ONNX format.

        This method exports the ranking model (e.g., DeepFM, WideDeep, DCN) to ONNX format
        for deployment. The export is non-invasive and does not modify the model code.

        Args:
            output_path (str): Path to save the ONNX model file.
            dummy_input (dict, optional): Example input dict {feature_name: tensor}.
                If not provided, dummy inputs will be generated automatically.
            batch_size (int): Batch size for auto-generated dummy input (default: 2).
            seq_length (int): Sequence length for SequenceFeature (default: 10).
            opset_version (int): ONNX opset version (default: 14).
            dynamic_batch (bool): Enable dynamic batch size (default: True).
            device (str, optional): Device for export ('cpu', 'cuda', etc.).
                If None, defaults to 'cpu' for maximum compatibility.
            verbose (bool): Print export details (default: False).

        Returns:
            bool: True if export succeeded, False otherwise.

        Example:
            >>> trainer = CTRTrainer(model, ...)
            >>> trainer.fit(train_dl, val_dl)
            >>> trainer.export_onnx("deepfm.onnx")

            >>> # With custom dummy input
            >>> dummy = {"user_id": torch.tensor([1, 2]), "item_id": torch.tensor([10, 20])}
            >>> trainer.export_onnx("model.onnx", dummy_input=dummy)

            >>> # Export on specific device
            >>> trainer.export_onnx("model.onnx", device="cpu")
        """
        from ..utils.onnx_export import ONNXExporter

        # Handle DataParallel wrapped model
        model = self.model.module if hasattr(self.model, 'module') else self.model

        # Use provided device or default to 'cpu'
        export_device = device if device is not None else 'cpu'

        exporter = ONNXExporter(model, device=export_device)
        return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
|