torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_rechub/__init__.py +14 -0
- torch_rechub/basic/activation.py +3 -1
- torch_rechub/basic/callback.py +2 -2
- torch_rechub/basic/features.py +38 -8
- torch_rechub/basic/initializers.py +92 -0
- torch_rechub/basic/layers.py +800 -46
- torch_rechub/basic/loss_func.py +223 -0
- torch_rechub/basic/metaoptimizer.py +76 -0
- torch_rechub/basic/metric.py +251 -0
- torch_rechub/models/generative/__init__.py +6 -0
- torch_rechub/models/generative/hllm.py +249 -0
- torch_rechub/models/generative/hstu.py +189 -0
- torch_rechub/models/matching/__init__.py +13 -0
- torch_rechub/models/matching/comirec.py +193 -0
- torch_rechub/models/matching/dssm.py +72 -0
- torch_rechub/models/matching/dssm_facebook.py +77 -0
- torch_rechub/models/matching/dssm_senet.py +87 -0
- torch_rechub/models/matching/gru4rec.py +85 -0
- torch_rechub/models/matching/mind.py +103 -0
- torch_rechub/models/matching/narm.py +82 -0
- torch_rechub/models/matching/sasrec.py +143 -0
- torch_rechub/models/matching/sine.py +148 -0
- torch_rechub/models/matching/stamp.py +81 -0
- torch_rechub/models/matching/youtube_dnn.py +75 -0
- torch_rechub/models/matching/youtube_sbc.py +98 -0
- torch_rechub/models/multi_task/__init__.py +5 -2
- torch_rechub/models/multi_task/aitm.py +83 -0
- torch_rechub/models/multi_task/esmm.py +19 -8
- torch_rechub/models/multi_task/mmoe.py +18 -12
- torch_rechub/models/multi_task/ple.py +41 -29
- torch_rechub/models/multi_task/shared_bottom.py +3 -2
- torch_rechub/models/ranking/__init__.py +13 -2
- torch_rechub/models/ranking/afm.py +65 -0
- torch_rechub/models/ranking/autoint.py +102 -0
- torch_rechub/models/ranking/bst.py +61 -0
- torch_rechub/models/ranking/dcn.py +38 -0
- torch_rechub/models/ranking/dcn_v2.py +59 -0
- torch_rechub/models/ranking/deepffm.py +131 -0
- torch_rechub/models/ranking/deepfm.py +8 -7
- torch_rechub/models/ranking/dien.py +191 -0
- torch_rechub/models/ranking/din.py +31 -19
- torch_rechub/models/ranking/edcn.py +101 -0
- torch_rechub/models/ranking/fibinet.py +42 -0
- torch_rechub/models/ranking/widedeep.py +6 -6
- torch_rechub/trainers/__init__.py +4 -2
- torch_rechub/trainers/ctr_trainer.py +191 -0
- torch_rechub/trainers/match_trainer.py +239 -0
- torch_rechub/trainers/matching.md +3 -0
- torch_rechub/trainers/mtl_trainer.py +137 -23
- torch_rechub/trainers/seq_trainer.py +293 -0
- torch_rechub/utils/__init__.py +0 -0
- torch_rechub/utils/data.py +492 -0
- torch_rechub/utils/hstu_utils.py +198 -0
- torch_rechub/utils/match.py +457 -0
- torch_rechub/utils/mtl.py +136 -0
- torch_rechub/utils/onnx_export.py +353 -0
- torch_rechub-0.0.4.dist-info/METADATA +391 -0
- torch_rechub-0.0.4.dist-info/RECORD +62 -0
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
- {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
- torch_rechub/basic/utils.py +0 -168
- torch_rechub/trainers/trainer.py +0 -111
- torch_rechub-0.0.1.dist-info/METADATA +0 -105
- torch_rechub-0.0.1.dist-info/RECORD +0 -26
- torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import tqdm
|
|
5
|
+
from sklearn.metrics import roc_auc_score
|
|
6
|
+
|
|
7
|
+
from ..basic.callback import EarlyStopper
|
|
8
|
+
from ..basic.loss_func import BPRLoss, RegularizationLoss
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MatchTrainer(object):
    """A general trainer for Matching/Retrieval models.

    Supports three training paradigms selected by ``mode``:
    point-wise (BCE), pair-wise (BPR) and list-wise (softmax CE).

    Args:
        model (nn.Module): any matching model.
        mode (int, optional): the training mode, `{0:point-wise, 1:pair-wise, 2:list-wise}`. Defaults to 0.
        optimizer_fn (torch.optim): optimizer function of pytorch (default = `torch.optim.Adam`).
        optimizer_params (dict): parameters of optimizer_fn.
        regularization_params (dict): L1/L2 penalty coefficients for embedding and
            dense parameters, forwarded to :class:`RegularizationLoss`.
        scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
        scheduler_params (dict): parameters of optimizer scheduler_fn.
        n_epoch (int): epoch number of training.
        earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
        device (str): `"cpu"` or `"cuda:0"`
        gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
        model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
    """

    def __init__(
        self,
        model,
        mode=0,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=None,
        regularization_params=None,
        scheduler_fn=None,
        scheduler_params=None,
        n_epoch=10,
        earlystop_patience=10,
        device="cpu",
        gpus=None,
        model_path="./",
    ):
        self.model = model  # for uniform weights save method in one gpu or multi gpu
        # avoid a mutable default argument: gpus defaults to an empty list
        if gpus is None:
            gpus = []
        self.gpus = gpus
        if len(gpus) > 1:
            print('parallel running on these gpus:', gpus)
            self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
        # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)
        self.model.to(self.device)
        if optimizer_params is None:
            optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
        if regularization_params is None:
            regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
        self.mode = mode
        if mode == 0:  # point-wise loss, binary cross_entropy
            self.criterion = torch.nn.BCELoss()  # default loss binary cross_entropy
        elif mode == 1:  # pair-wise loss
            self.criterion = BPRLoss()
        elif mode == 2:  # list-wise loss, softmax
            self.criterion = torch.nn.CrossEntropyLoss()
        else:
            raise ValueError("mode only contain value in %s, but got %s" % ([0, 1, 2], mode))
        self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
        self.scheduler = None
        if scheduler_fn is not None:
            self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
        self.evaluate_fn = roc_auc_score  # default evaluate function
        self.n_epoch = n_epoch
        self.early_stopper = EarlyStopper(patience=earlystop_patience)
        self.model_path = model_path
        # Initialize regularization loss
        self.reg_loss_fn = RegularizationLoss(**regularization_params)

    def train_one_epoch(self, data_loader, log_interval=10):
        """Run one training epoch over ``data_loader``.

        Args:
            data_loader: iterable yielding ``(x_dict, y)`` batches.
            log_interval (int): report the running mean loss every this many batches.
        """
        self.model.train()
        total_loss = 0
        tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
        for i, (x_dict, y) in enumerate(tk0):
            x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
            y = y.to(self.device)
            if self.mode == 0:
                y = y.float()  # torch._C._nn.binary_cross_entropy expected Float
            else:
                y = y.long()  # CrossEntropyLoss expects class indices
            if self.mode == 1:  # pair_wise
                # pair-wise models return (positive score, negative score)
                pos_score, neg_score = self.model(x_dict)
                loss = self.criterion(pos_score, neg_score)
            else:
                y_pred = self.model(x_dict)
                loss = self.criterion(y_pred, y)

            # Add regularization loss
            reg_loss = self.reg_loss_fn(self.model)
            loss = loss + reg_loss

            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            if (i + 1) % log_interval == 0:
                tk0.set_postfix(loss=total_loss / log_interval)
                total_loss = 0

    def fit(self, train_dataloader, val_dataloader=None):
        """Train for ``self.n_epoch`` epochs with optional validation/early stopping.

        After every validated epoch the current weights are written to
        ``<model_path>/model.pth``; when early stopping triggers, the *best*
        weights are restored and persisted so the on-disk checkpoint always
        matches the best validation AUC. Without a validation loader the final
        weights are saved once after training.
        """
        checkpoint = os.path.join(self.model_path, "model.pth")
        for epoch_i in range(self.n_epoch):
            print('epoch:', epoch_i)
            self.train_one_epoch(train_dataloader)
            if self.scheduler is not None:
                # NOTE(review): assumes a StepLR-like scheduler exposing
                # `step_size` — other schedulers would raise AttributeError here.
                if epoch_i % self.scheduler.step_size == 0:
                    print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
                self.scheduler.step()  # update lr in epoch level by scheduler

            if val_dataloader:
                auc = self.evaluate(self.model, val_dataloader)
                print('epoch:', epoch_i, 'validation: auc:', auc)
                if self.early_stopper.stop_training(auc, self.model.state_dict()):
                    print(f'validation: best auc: {self.early_stopper.best_auc}')
                    self.model.load_state_dict(self.early_stopper.best_weights)
                    # BUGFIX: persist the restored best weights; previously the
                    # `break` skipped the save and model.pth kept the stale
                    # last-epoch weights, which inference_embedding then loaded.
                    torch.save(self.model.state_dict(), checkpoint)
                    break
                torch.save(self.model.state_dict(), checkpoint)  # save best auc model
        else:
            # Loop completed without early stopping. If no validation loader
            # was given, nothing has been saved yet — persist the final weights
            # so inference_embedding can find model.pth.
            if not val_dataloader:
                torch.save(self.model.state_dict(), checkpoint)

    def evaluate(self, model, data_loader):
        """Score ``model`` on ``data_loader`` with ``self.evaluate_fn`` (AUC by default)."""
        model.eval()
        targets, predicts = list(), list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
            for i, (x_dict, y) in enumerate(tk0):
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                y = y.to(self.device)
                y_pred = model(x_dict)
                targets.extend(y.tolist())
                predicts.extend(y_pred.tolist())
        return self.evaluate_fn(targets, predicts)

    def predict(self, model, data_loader):
        """Return model predictions for every batch as a flat python list."""
        model.eval()
        predicts = list()
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
            for i, (x_dict, y) in enumerate(tk0):
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                y = y.to(self.device)
                y_pred = model(x_dict)
                predicts.extend(y_pred.tolist())
        return predicts

    def inference_embedding(self, model, mode, data_loader, model_path):
        """Compute user or item tower embeddings with the saved best checkpoint.

        Args:
            model: a dual-tower model exposing a ``mode`` attribute.
            mode (str): ``"user"`` or ``"item"`` — which tower to run.
            data_loader: iterable yielding feature dicts (no labels).
            model_path (str): directory containing ``model.pth``.

        Returns:
            torch.Tensor: embeddings concatenated along the batch dimension.
        """
        # inference
        assert mode in ["user", "item"], "Invalid mode={}.".format(mode)
        model.mode = mode
        model.load_state_dict(torch.load(os.path.join(model_path, "model.pth"), map_location=self.device, weights_only=True))
        model = model.to(self.device)
        model.eval()
        predicts = []
        with torch.no_grad():
            tk0 = tqdm.tqdm(data_loader, desc="%s inference" % (mode), smoothing=0, mininterval=1.0)
            for i, x_dict in enumerate(tk0):
                x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                y_pred = model(x_dict)
                predicts.append(y_pred.data)
        return torch.cat(predicts, dim=0)

    def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
        """Export the trained matching model to ONNX format.

        This method exports matching/retrieval models (e.g., DSSM, YoutubeDNN, MIND)
        to ONNX format. For dual-tower models, you can export user tower and item
        tower separately for efficient online serving.

        Args:
            output_path (str): Path to save the ONNX model file.
            mode (str, optional): Export mode for dual-tower models:
                - "user": Export only the user tower (for user embedding inference)
                - "item": Export only the item tower (for item embedding inference)
                - None: Export the full model (default)
            dummy_input (dict, optional): Example input dict {feature_name: tensor}.
                If not provided, dummy inputs will be generated automatically.
            batch_size (int): Batch size for auto-generated dummy input (default: 2).
            seq_length (int): Sequence length for SequenceFeature (default: 10).
            opset_version (int): ONNX opset version (default: 14).
            dynamic_batch (bool): Enable dynamic batch size (default: True).
            device (str, optional): Device for export ('cpu', 'cuda', etc.).
                If None, defaults to 'cpu' for maximum compatibility.
            verbose (bool): Print export details (default: False).

        Returns:
            bool: True if export succeeded, False otherwise.

        Example:
            >>> trainer = MatchTrainer(dssm_model, mode=0, ...)
            >>> trainer.fit(train_dl)

            >>> # Export user tower for user embedding inference
            >>> trainer.export_onnx("user_tower.onnx", mode="user")

            >>> # Export item tower for item embedding inference
            >>> trainer.export_onnx("item_tower.onnx", mode="item")

            >>> # Export full model (for online similarity computation)
            >>> trainer.export_onnx("full_model.onnx")

            >>> # Export on specific device
            >>> trainer.export_onnx("user_tower.onnx", mode="user", device="cpu")
        """
        from ..utils.onnx_export import ONNXExporter

        # Handle DataParallel wrapped model
        model = self.model.module if hasattr(self.model, 'module') else self.model

        # Store original mode
        original_mode = getattr(model, 'mode', None)

        # Use provided device or default to 'cpu'
        export_device = device if device is not None else 'cpu'

        try:
            exporter = ONNXExporter(model, device=export_device)
            return exporter.export(output_path=output_path, mode=mode, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
        finally:
            # Restore original mode (the exporter may have switched towers)
            if hasattr(model, 'mode'):
                model.mode = original_mode
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
import os
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
5
5
|
import torch.nn as nn
|
|
6
|
+
import tqdm
|
|
7
|
+
|
|
6
8
|
from ..basic.callback import EarlyStopper
|
|
7
|
-
from ..basic.
|
|
9
|
+
from ..basic.loss_func import RegularizationLoss
|
|
8
10
|
from ..models.multi_task import ESMM
|
|
11
|
+
from ..utils.data import get_loss_func, get_metric_func
|
|
12
|
+
from ..utils.mtl import MetaBalance, gradnorm, shared_task_layers
|
|
9
13
|
|
|
10
14
|
|
|
11
15
|
class MTLTrainer(object):
|
|
@@ -18,7 +22,7 @@ class MTLTrainer(object):
|
|
|
18
22
|
optimizer_params (dict): parameters of optimizer_fn.
|
|
19
23
|
scheduler_fn (torch.optim.lr_scheduler) : torch scheduling class, eg. `torch.optim.lr_scheduler.StepLR`.
|
|
20
24
|
scheduler_params (dict): parameters of optimizer scheduler_fn.
|
|
21
|
-
adaptive_params (dict): parameters of adaptive loss weight method. Now only support `{"method" : "uwl"}`.
|
|
25
|
+
adaptive_params (dict): parameters of adaptive loss weight method. Now only support `{"method" : "uwl"}`.
|
|
22
26
|
n_epoch (int): epoch number of training.
|
|
23
27
|
earlystop_taskid (int): task id of earlystop metrics relies between multi task (default = 0).
|
|
24
28
|
earlystop_patience (int): how long to wait after last time validation auc improved (default = 10).
|
|
@@ -32,10 +36,8 @@ class MTLTrainer(object):
|
|
|
32
36
|
model,
|
|
33
37
|
task_types,
|
|
34
38
|
optimizer_fn=torch.optim.Adam,
|
|
35
|
-
optimizer_params=
|
|
36
|
-
|
|
37
|
-
"weight_decay": 1e-5
|
|
38
|
-
},
|
|
39
|
+
optimizer_params=None,
|
|
40
|
+
regularization_params=None,
|
|
39
41
|
scheduler_fn=None,
|
|
40
42
|
scheduler_params=None,
|
|
41
43
|
adaptive_params=None,
|
|
@@ -43,10 +45,16 @@ class MTLTrainer(object):
|
|
|
43
45
|
earlystop_taskid=0,
|
|
44
46
|
earlystop_patience=10,
|
|
45
47
|
device="cpu",
|
|
46
|
-
gpus=
|
|
48
|
+
gpus=None,
|
|
47
49
|
model_path="./",
|
|
48
50
|
):
|
|
49
51
|
self.model = model
|
|
52
|
+
if gpus is None:
|
|
53
|
+
gpus = []
|
|
54
|
+
if optimizer_params is None:
|
|
55
|
+
optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
|
|
56
|
+
if regularization_params is None:
|
|
57
|
+
regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
|
|
50
58
|
self.task_types = task_types
|
|
51
59
|
self.n_task = len(task_types)
|
|
52
60
|
self.loss_weight = None
|
|
@@ -56,7 +64,27 @@ class MTLTrainer(object):
|
|
|
56
64
|
self.adaptive_method = "uwl"
|
|
57
65
|
self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
|
|
58
66
|
self.model.add_module("loss weight", self.loss_weight)
|
|
59
|
-
|
|
67
|
+
elif adaptive_params["method"] == "metabalance":
|
|
68
|
+
self.adaptive_method = "metabalance"
|
|
69
|
+
share_layers, task_layers = shared_task_layers(self.model)
|
|
70
|
+
self.meta_optimizer = MetaBalance(share_layers)
|
|
71
|
+
self.share_optimizer = optimizer_fn(share_layers, **optimizer_params)
|
|
72
|
+
self.task_optimizer = optimizer_fn(task_layers, **optimizer_params)
|
|
73
|
+
elif adaptive_params["method"] == "gradnorm":
|
|
74
|
+
self.adaptive_method = "gradnorm"
|
|
75
|
+
self.alpha = adaptive_params.get("alpha", 0.16)
|
|
76
|
+
share_layers = shared_task_layers(self.model)[0]
|
|
77
|
+
# gradnorm calculate the gradients of each loss on the last
|
|
78
|
+
# fully connected shared layer weight(dimension is 2)
|
|
79
|
+
for i in range(len(share_layers)):
|
|
80
|
+
if share_layers[-i].ndim == 2:
|
|
81
|
+
self.last_share_layer = share_layers[-i]
|
|
82
|
+
break
|
|
83
|
+
self.initial_task_loss = None
|
|
84
|
+
self.loss_weight = nn.ParameterList(nn.Parameter(torch.ones(1)) for _ in range(self.n_task))
|
|
85
|
+
self.model.add_module("loss weight", self.loss_weight)
|
|
86
|
+
if self.adaptive_method != "metabalance":
|
|
87
|
+
self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params) # default Adam optimizer
|
|
60
88
|
self.scheduler = None
|
|
61
89
|
if scheduler_fn is not None:
|
|
62
90
|
self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
|
|
@@ -65,27 +93,32 @@ class MTLTrainer(object):
|
|
|
65
93
|
self.n_epoch = n_epoch
|
|
66
94
|
self.earlystop_taskid = earlystop_taskid
|
|
67
95
|
self.early_stopper = EarlyStopper(patience=earlystop_patience)
|
|
68
|
-
self.device = torch.device(device)
|
|
69
96
|
|
|
70
97
|
self.gpus = gpus
|
|
71
98
|
if len(gpus) > 1:
|
|
72
99
|
print('parallel running on these gpus:', gpus)
|
|
73
100
|
self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
|
|
101
|
+
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
102
|
+
self.device = torch.device(device)
|
|
103
|
+
self.model.to(self.device)
|
|
74
104
|
self.model_path = model_path
|
|
105
|
+
# Initialize regularization loss
|
|
106
|
+
self.reg_loss_fn = RegularizationLoss(**regularization_params)
|
|
75
107
|
|
|
76
108
|
def train_one_epoch(self, data_loader):
|
|
77
109
|
self.model.train()
|
|
78
110
|
total_loss = np.zeros(self.n_task)
|
|
79
111
|
tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
|
|
80
112
|
for iter_i, (x_dict, ys) in enumerate(tk0):
|
|
81
|
-
x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
|
|
113
|
+
x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
|
|
82
114
|
ys = ys.to(self.device)
|
|
83
115
|
y_preds = self.model(x_dict)
|
|
84
116
|
loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
|
|
85
117
|
if isinstance(self.model, ESMM):
|
|
86
|
-
|
|
118
|
+
# ESSM only compute loss for ctr and ctcvr task
|
|
119
|
+
loss = sum(loss_list[1:])
|
|
87
120
|
else:
|
|
88
|
-
if self.adaptive_method
|
|
121
|
+
if self.adaptive_method is not None:
|
|
89
122
|
if self.adaptive_method == "uwl":
|
|
90
123
|
loss = 0
|
|
91
124
|
for loss_i, w_i in zip(loss_list, self.loss_weight):
|
|
@@ -93,38 +126,74 @@ class MTLTrainer(object):
|
|
|
93
126
|
loss += 2 * loss_i * torch.exp(-w_i) + w_i
|
|
94
127
|
else:
|
|
95
128
|
loss = sum(loss_list) / self.n_task
|
|
96
|
-
|
|
97
|
-
loss
|
|
98
|
-
self.
|
|
129
|
+
|
|
130
|
+
# Add regularization loss
|
|
131
|
+
reg_loss = self.reg_loss_fn(self.model)
|
|
132
|
+
loss = loss + reg_loss
|
|
133
|
+
if self.adaptive_method == 'metabalance':
|
|
134
|
+
self.share_optimizer.zero_grad()
|
|
135
|
+
self.task_optimizer.zero_grad()
|
|
136
|
+
self.meta_optimizer.step(loss_list)
|
|
137
|
+
self.share_optimizer.step()
|
|
138
|
+
self.task_optimizer.step()
|
|
139
|
+
elif self.adaptive_method == "gradnorm":
|
|
140
|
+
self.optimizer.zero_grad()
|
|
141
|
+
if self.initial_task_loss is None:
|
|
142
|
+
self.initial_task_loss = [l.item() for l in loss_list]
|
|
143
|
+
gradnorm(loss_list, self.loss_weight, self.last_share_layer, self.initial_task_loss, self.alpha)
|
|
144
|
+
self.optimizer.step()
|
|
145
|
+
# renormalize
|
|
146
|
+
loss_weight_sum = sum([w.item() for w in self.loss_weight])
|
|
147
|
+
normalize_coeff = len(self.loss_weight) / loss_weight_sum
|
|
148
|
+
for w in self.loss_weight:
|
|
149
|
+
w.data = w.data * normalize_coeff
|
|
150
|
+
else:
|
|
151
|
+
self.model.zero_grad()
|
|
152
|
+
loss.backward()
|
|
153
|
+
self.optimizer.step()
|
|
99
154
|
total_loss += np.array([l.item() for l in loss_list])
|
|
100
155
|
log_dict = {"task_%d:" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
|
|
156
|
+
loss_list = [total_loss[i] / (iter_i + 1) for i in range(self.n_task)]
|
|
101
157
|
print("train loss: ", log_dict)
|
|
102
158
|
if self.loss_weight:
|
|
103
159
|
print("loss weight: ", [w.item() for w in self.loss_weight])
|
|
104
160
|
|
|
105
|
-
|
|
106
|
-
|
|
161
|
+
return loss_list
|
|
162
|
+
|
|
163
|
+
def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
|
|
164
|
+
total_log = []
|
|
165
|
+
|
|
107
166
|
for epoch_i in range(self.n_epoch):
|
|
108
|
-
self.train_one_epoch(train_dataloader)
|
|
167
|
+
_log_per_epoch = self.train_one_epoch(train_dataloader)
|
|
168
|
+
|
|
109
169
|
if self.scheduler is not None:
|
|
110
170
|
if epoch_i % self.scheduler.step_size == 0:
|
|
111
171
|
print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
|
|
112
|
-
self.scheduler.step() #update lr in epoch level by scheduler
|
|
172
|
+
self.scheduler.step() # update lr in epoch level by scheduler
|
|
113
173
|
scores = self.evaluate(self.model, val_dataloader)
|
|
114
174
|
print('epoch:', epoch_i, 'validation scores: ', scores)
|
|
175
|
+
|
|
176
|
+
for score in scores:
|
|
177
|
+
_log_per_epoch.append(score)
|
|
178
|
+
|
|
179
|
+
total_log.append(_log_per_epoch)
|
|
180
|
+
|
|
115
181
|
if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
|
|
116
182
|
print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
|
|
117
183
|
self.model.load_state_dict(self.early_stopper.best_weights)
|
|
118
|
-
torch.save(self.early_stopper.best_weights, os.path.join(self.model_path, "model.pth")) #save best auc model
|
|
119
184
|
break
|
|
120
185
|
|
|
186
|
+
torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed))) # save best auc model
|
|
187
|
+
|
|
188
|
+
return total_log
|
|
189
|
+
|
|
121
190
|
def evaluate(self, model, data_loader):
|
|
122
191
|
model.eval()
|
|
123
192
|
targets, predicts = list(), list()
|
|
124
193
|
with torch.no_grad():
|
|
125
194
|
tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
|
|
126
195
|
for i, (x_dict, ys) in enumerate(tk0):
|
|
127
|
-
x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
|
|
196
|
+
x_dict = {k: v.to(self.device) for k, v in x_dict.items()} # tensor to GPU
|
|
128
197
|
ys = ys.to(self.device)
|
|
129
198
|
y_preds = self.model(x_dict)
|
|
130
199
|
targets.extend(ys.tolist())
|
|
@@ -142,4 +211,49 @@ class MTLTrainer(object):
|
|
|
142
211
|
x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
|
|
143
212
|
y_preds = model(x_dict)
|
|
144
213
|
predicts.extend(y_preds.tolist())
|
|
145
|
-
return predicts
|
|
214
|
+
return predicts
|
|
215
|
+
|
|
216
|
+
def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
|
|
217
|
+
"""Export the trained multi-task model to ONNX format.
|
|
218
|
+
|
|
219
|
+
This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
|
|
220
|
+
to ONNX format for deployment. The exported model will have multiple outputs
|
|
221
|
+
corresponding to each task.
|
|
222
|
+
|
|
223
|
+
Note:
|
|
224
|
+
The ONNX model will output a tensor of shape [batch_size, n_task] where
|
|
225
|
+
n_task is the number of tasks in the multi-task model.
|
|
226
|
+
|
|
227
|
+
Args:
|
|
228
|
+
output_path (str): Path to save the ONNX model file.
|
|
229
|
+
dummy_input (dict, optional): Example input dict {feature_name: tensor}.
|
|
230
|
+
If not provided, dummy inputs will be generated automatically.
|
|
231
|
+
batch_size (int): Batch size for auto-generated dummy input (default: 2).
|
|
232
|
+
seq_length (int): Sequence length for SequenceFeature (default: 10).
|
|
233
|
+
opset_version (int): ONNX opset version (default: 14).
|
|
234
|
+
dynamic_batch (bool): Enable dynamic batch size (default: True).
|
|
235
|
+
device (str, optional): Device for export ('cpu', 'cuda', etc.).
|
|
236
|
+
If None, defaults to 'cpu' for maximum compatibility.
|
|
237
|
+
verbose (bool): Print export details (default: False).
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
bool: True if export succeeded, False otherwise.
|
|
241
|
+
|
|
242
|
+
Example:
|
|
243
|
+
>>> trainer = MTLTrainer(mmoe_model, task_types=["classification", "classification"], ...)
|
|
244
|
+
>>> trainer.fit(train_dl, val_dl)
|
|
245
|
+
>>> trainer.export_onnx("mmoe.onnx")
|
|
246
|
+
|
|
247
|
+
>>> # Export on specific device
|
|
248
|
+
>>> trainer.export_onnx("mmoe.onnx", device="cpu")
|
|
249
|
+
"""
|
|
250
|
+
from ..utils.onnx_export import ONNXExporter
|
|
251
|
+
|
|
252
|
+
# Handle DataParallel wrapped model
|
|
253
|
+
model = self.model.module if hasattr(self.model, 'module') else self.model
|
|
254
|
+
|
|
255
|
+
# Use provided device or default to 'cpu'
|
|
256
|
+
export_device = device if device is not None else 'cpu'
|
|
257
|
+
|
|
258
|
+
exporter = ONNXExporter(model, device=export_device)
|
|
259
|
+
return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
|