torch-rechub 0.0.1__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. torch_rechub/__init__.py +14 -0
  2. torch_rechub/basic/activation.py +3 -1
  3. torch_rechub/basic/callback.py +2 -2
  4. torch_rechub/basic/features.py +38 -8
  5. torch_rechub/basic/initializers.py +92 -0
  6. torch_rechub/basic/layers.py +800 -46
  7. torch_rechub/basic/loss_func.py +223 -0
  8. torch_rechub/basic/metaoptimizer.py +76 -0
  9. torch_rechub/basic/metric.py +251 -0
  10. torch_rechub/models/generative/__init__.py +6 -0
  11. torch_rechub/models/generative/hllm.py +249 -0
  12. torch_rechub/models/generative/hstu.py +189 -0
  13. torch_rechub/models/matching/__init__.py +13 -0
  14. torch_rechub/models/matching/comirec.py +193 -0
  15. torch_rechub/models/matching/dssm.py +72 -0
  16. torch_rechub/models/matching/dssm_facebook.py +77 -0
  17. torch_rechub/models/matching/dssm_senet.py +87 -0
  18. torch_rechub/models/matching/gru4rec.py +85 -0
  19. torch_rechub/models/matching/mind.py +103 -0
  20. torch_rechub/models/matching/narm.py +82 -0
  21. torch_rechub/models/matching/sasrec.py +143 -0
  22. torch_rechub/models/matching/sine.py +148 -0
  23. torch_rechub/models/matching/stamp.py +81 -0
  24. torch_rechub/models/matching/youtube_dnn.py +75 -0
  25. torch_rechub/models/matching/youtube_sbc.py +98 -0
  26. torch_rechub/models/multi_task/__init__.py +5 -2
  27. torch_rechub/models/multi_task/aitm.py +83 -0
  28. torch_rechub/models/multi_task/esmm.py +19 -8
  29. torch_rechub/models/multi_task/mmoe.py +18 -12
  30. torch_rechub/models/multi_task/ple.py +41 -29
  31. torch_rechub/models/multi_task/shared_bottom.py +3 -2
  32. torch_rechub/models/ranking/__init__.py +13 -2
  33. torch_rechub/models/ranking/afm.py +65 -0
  34. torch_rechub/models/ranking/autoint.py +102 -0
  35. torch_rechub/models/ranking/bst.py +61 -0
  36. torch_rechub/models/ranking/dcn.py +38 -0
  37. torch_rechub/models/ranking/dcn_v2.py +59 -0
  38. torch_rechub/models/ranking/deepffm.py +131 -0
  39. torch_rechub/models/ranking/deepfm.py +8 -7
  40. torch_rechub/models/ranking/dien.py +191 -0
  41. torch_rechub/models/ranking/din.py +31 -19
  42. torch_rechub/models/ranking/edcn.py +101 -0
  43. torch_rechub/models/ranking/fibinet.py +42 -0
  44. torch_rechub/models/ranking/widedeep.py +6 -6
  45. torch_rechub/trainers/__init__.py +4 -2
  46. torch_rechub/trainers/ctr_trainer.py +191 -0
  47. torch_rechub/trainers/match_trainer.py +239 -0
  48. torch_rechub/trainers/matching.md +3 -0
  49. torch_rechub/trainers/mtl_trainer.py +137 -23
  50. torch_rechub/trainers/seq_trainer.py +293 -0
  51. torch_rechub/utils/__init__.py +0 -0
  52. torch_rechub/utils/data.py +492 -0
  53. torch_rechub/utils/hstu_utils.py +198 -0
  54. torch_rechub/utils/match.py +457 -0
  55. torch_rechub/utils/mtl.py +136 -0
  56. torch_rechub/utils/onnx_export.py +353 -0
  57. torch_rechub-0.0.4.dist-info/METADATA +391 -0
  58. torch_rechub-0.0.4.dist-info/RECORD +62 -0
  59. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info}/WHEEL +1 -2
  60. {torch_rechub-0.0.1.dist-info → torch_rechub-0.0.4.dist-info/licenses}/LICENSE +1 -1
  61. torch_rechub/basic/utils.py +0 -168
  62. torch_rechub/trainers/trainer.py +0 -111
  63. torch_rechub-0.0.1.dist-info/METADATA +0 -105
  64. torch_rechub-0.0.1.dist-info/RECORD +0 -26
  65. torch_rechub-0.0.1.dist-info/top_level.txt +0 -1
torch_rechub/trainers/match_trainer.py
@@ -0,0 +1,239 @@
+ import os
+
+ import torch
+ import tqdm
+ from sklearn.metrics import roc_auc_score
+
+ from ..basic.callback import EarlyStopper
+ from ..basic.loss_func import BPRLoss, RegularizationLoss
+
+
+ class MatchTrainer(object):
+     """A general trainer for matching/retrieval models.
+
+     Args:
+         model (nn.Module): any matching model.
+         mode (int, optional): the training mode, `{0: point-wise, 1: pair-wise, 2: list-wise}`. Defaults to 0.
+         optimizer_fn (torch.optim): optimizer class of PyTorch (default = `torch.optim.Adam`).
+         optimizer_params (dict): parameters of optimizer_fn.
+         scheduler_fn (torch.optim.lr_scheduler): torch scheduler class, e.g. `torch.optim.lr_scheduler.StepLR`.
+         scheduler_params (dict): parameters of scheduler_fn.
+         n_epoch (int): number of training epochs.
+         earlystop_patience (int): how many epochs to wait after the last validation AUC improvement (default=10).
+         device (str): `"cpu"` or `"cuda:0"`.
+         gpus (list): ids for multi-GPU training (default=[]). If the length is >= 1, the model will be wrapped by nn.DataParallel.
+         model_path (str): the path to save the model (default="./"). Note that only the weights that are best on the validation data are saved.
+     """
+
+     def __init__(
+         self,
+         model,
+         mode=0,
+         optimizer_fn=torch.optim.Adam,
+         optimizer_params=None,
+         regularization_params=None,
+         scheduler_fn=None,
+         scheduler_params=None,
+         n_epoch=10,
+         earlystop_patience=10,
+         device="cpu",
+         gpus=None,
+         model_path="./",
+     ):
+         self.model = model  # kept so weights are saved the same way on one GPU or multiple GPUs
+         if gpus is None:
+             gpus = []
+         self.gpus = gpus
+         if len(gpus) > 1:
+             print('parallel running on these gpus:', gpus)
+             self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+         # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.device = torch.device(device)
+         self.model.to(self.device)
+         if optimizer_params is None:
+             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+         if regularization_params is None:
+             regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
+         self.mode = mode
+         if mode == 0:  # point-wise loss, binary cross-entropy
+             self.criterion = torch.nn.BCELoss()  # default loss: binary cross-entropy
+         elif mode == 1:  # pair-wise loss
+             self.criterion = BPRLoss()
+         elif mode == 2:  # list-wise loss, softmax
+             self.criterion = torch.nn.CrossEntropyLoss()
+         else:
+             raise ValueError("mode must be one of %s, but got %s" % ([0, 1, 2], mode))
+         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default optimizer
+         self.scheduler = None
+         if scheduler_fn is not None:
+             self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
+         self.evaluate_fn = roc_auc_score  # default evaluate function
+         self.n_epoch = n_epoch
+         self.early_stopper = EarlyStopper(patience=earlystop_patience)
+         self.model_path = model_path
+         # Initialize regularization loss
+         self.reg_loss_fn = RegularizationLoss(**regularization_params)
+
+     def train_one_epoch(self, data_loader, log_interval=10):
+         self.model.train()
+         total_loss = 0
+         tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
+         for i, (x_dict, y) in enumerate(tk0):
+             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
+             y = y.to(self.device)
+             if self.mode == 0:
+                 y = y.float()  # torch._C._nn.binary_cross_entropy expects Float
+             else:
+                 y = y.long()  # CrossEntropyLoss expects Long targets
+             if self.mode == 1:  # pair-wise
+                 pos_score, neg_score = self.model(x_dict)
+                 loss = self.criterion(pos_score, neg_score)
+             else:
+                 y_pred = self.model(x_dict)
+                 loss = self.criterion(y_pred, y)
+
+             # Add regularization loss
+             reg_loss = self.reg_loss_fn(self.model)
+             loss = loss + reg_loss
+
+             # used for debugging
+             # if i == 0:
+             #     print()
+             #     if self.mode == 0:
+             #         print('pred: ', [f'{float(each):5.2g}' for each in y_pred.detach().cpu().tolist()])
+             #         print('truth:', [f'{float(each):5.2g}' for each in y.detach().cpu().tolist()])
+             #     elif self.mode == 2:
+             #         pred = y_pred.detach().cpu().mean(0)
+             #         pred = torch.softmax(pred, dim=0).tolist()
+             #         print('pred: ', [f'{float(each):4.2g}' for each in pred])
+             #     elif self.mode == 1:
+             #         print('pos:', [f'{float(each):5.2g}' for each in pos_score.detach().cpu().tolist()])
+             #         print('neg: ', [f'{float(each):5.2g}' for each in neg_score.detach().cpu().tolist()])
+
+             self.model.zero_grad()
+             loss.backward()
+             self.optimizer.step()
+             total_loss += loss.item()
+             if (i + 1) % log_interval == 0:
+                 tk0.set_postfix(loss=total_loss / log_interval)
+                 total_loss = 0
+
+     def fit(self, train_dataloader, val_dataloader=None):
+         for epoch_i in range(self.n_epoch):
+             print('epoch:', epoch_i)
+             self.train_one_epoch(train_dataloader)
+             if self.scheduler is not None:
+                 if epoch_i % self.scheduler.step_size == 0:
+                     print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
+                 self.scheduler.step()  # update lr at epoch level by scheduler
+
+             if val_dataloader:
+                 auc = self.evaluate(self.model, val_dataloader)
+                 print('epoch:', epoch_i, 'validation: auc:', auc)
+                 if self.early_stopper.stop_training(auc, self.model.state_dict()):
+                     print(f'validation: best auc: {self.early_stopper.best_auc}')
+                     self.model.load_state_dict(self.early_stopper.best_weights)
+                     break
+         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model.pth"))  # save best auc model
+
+     def evaluate(self, model, data_loader):
+         model.eval()
+         targets, predicts = list(), list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y = y.to(self.device)
+                 y_pred = model(x_dict)
+                 targets.extend(y.tolist())
+                 predicts.extend(y_pred.tolist())
+         return self.evaluate_fn(targets, predicts)
+
+     def predict(self, model, data_loader):
+         model.eval()
+         predicts = list()
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="predict", smoothing=0, mininterval=1.0)
+             for i, (x_dict, y) in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y = y.to(self.device)
+                 y_pred = model(x_dict)
+                 predicts.extend(y_pred.tolist())
+         return predicts
+
+     def inference_embedding(self, model, mode, data_loader, model_path):
+         # inference
+         assert mode in ["user", "item"], "Invalid mode={}.".format(mode)
+         model.mode = mode
+         model.load_state_dict(torch.load(os.path.join(model_path, "model.pth"), map_location=self.device, weights_only=True))
+         model = model.to(self.device)
+         model.eval()
+         predicts = []
+         with torch.no_grad():
+             tk0 = tqdm.tqdm(data_loader, desc="%s inference" % (mode), smoothing=0, mininterval=1.0)
+             for i, x_dict in enumerate(tk0):
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
+                 y_pred = model(x_dict)
+                 predicts.append(y_pred.data)
+         return torch.cat(predicts, dim=0)
+
+     def export_onnx(self, output_path, mode=None, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+         """Export the trained matching model to ONNX format.
+
+         This method exports matching/retrieval models (e.g., DSSM, YoutubeDNN, MIND)
+         to ONNX format. For dual-tower models, you can export the user tower and the
+         item tower separately for efficient online serving.
+
+         Args:
+             output_path (str): Path to save the ONNX model file.
+             mode (str, optional): Export mode for dual-tower models:
+                 - "user": Export only the user tower (for user embedding inference)
+                 - "item": Export only the item tower (for item embedding inference)
+                 - None: Export the full model (default)
+             dummy_input (dict, optional): Example input dict {feature_name: tensor}.
+                 If not provided, dummy inputs will be generated automatically.
+             batch_size (int): Batch size for auto-generated dummy input (default: 2).
+             seq_length (int): Sequence length for SequenceFeature (default: 10).
+             opset_version (int): ONNX opset version (default: 14).
+             dynamic_batch (bool): Enable dynamic batch size (default: True).
+             device (str, optional): Device for export ('cpu', 'cuda', etc.).
+                 If None, defaults to 'cpu' for maximum compatibility.
+             verbose (bool): Print export details (default: False).
+
+         Returns:
+             bool: True if export succeeded, False otherwise.
+
+         Example:
+             >>> trainer = MatchTrainer(dssm_model, mode=0, ...)
+             >>> trainer.fit(train_dl)
+
+             >>> # Export user tower for user embedding inference
+             >>> trainer.export_onnx("user_tower.onnx", mode="user")
+
+             >>> # Export item tower for item embedding inference
+             >>> trainer.export_onnx("item_tower.onnx", mode="item")
+
+             >>> # Export full model (for online similarity computation)
+             >>> trainer.export_onnx("full_model.onnx")
+
+             >>> # Export on a specific device
+             >>> trainer.export_onnx("user_tower.onnx", mode="user", device="cpu")
+         """
+         from ..utils.onnx_export import ONNXExporter
+
+         # Handle DataParallel-wrapped model
+         model = self.model.module if hasattr(self.model, 'module') else self.model
+
+         # Store original mode
+         original_mode = getattr(model, 'mode', None)
+
+         # Use provided device or default to 'cpu'
+         export_device = device if device is not None else 'cpu'
+
+         try:
+             exporter = ONNXExporter(model, device=export_device)
+             return exporter.export(output_path=output_path, mode=mode, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
+         finally:
+             # Restore original mode
+             if hasattr(model, 'mode'):
+                 model.mode = original_mode
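Taken together, the new trainer supports point-, pair- and list-wise objectives plus optional embedding/dense regularization. As a rough sketch of the intended flow (the two-tower model and the dataloaders `model`, `train_dl`, `val_dl`, `user_dl`, `item_dl` are assumed here, not defined in this diff; the import path assumes `MatchTrainer` is re-exported by `torch_rechub.trainers`, as the `__init__.py` change above suggests):

```python
import torch
from torch_rechub.trainers import MatchTrainer  # assumed re-export

trainer = MatchTrainer(
    model,                       # e.g. a DSSM two-tower model (assumed)
    mode=0,                      # point-wise training with BCELoss
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    regularization_params={"embedding_l1": 0.0, "embedding_l2": 1e-6,
                           "dense_l1": 0.0, "dense_l2": 0.0},
    n_epoch=10,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    model_path="./",
)
trainer.fit(train_dl, val_dl)

# Tower-wise inference reloads the saved "model.pth" and returns one
# [n_samples, emb_dim] tensor per tower for ANN indexing.
user_emb = trainer.inference_embedding(model, "user", user_dl, "./")
item_emb = trainer.inference_embedding(model, "item", item_dl, "./")
```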
torch_rechub/trainers/matching.md
@@ -0,0 +1,3 @@
+ # Matching
+
+ Retrieval usage documentation
torch_rechub/trainers/mtl_trainer.py
@@ -1,11 +1,15 @@
  import os
- import tqdm
+
  import numpy as np
  import torch
  import torch.nn as nn
+ import tqdm
+
  from ..basic.callback import EarlyStopper
- from ..basic.utils import get_loss_func, get_metric_func
+ from ..basic.loss_func import RegularizationLoss
  from ..models.multi_task import ESMM
+ from ..utils.data import get_loss_func, get_metric_func
+ from ..utils.mtl import MetaBalance, gradnorm, shared_task_layers


  class MTLTrainer(object):
@@ -18,7 +22,7 @@ class MTLTrainer(object):
          optimizer_params (dict): parameters of optimizer_fn.
          scheduler_fn (torch.optim.lr_scheduler): torch scheduler class, e.g. `torch.optim.lr_scheduler.StepLR`.
          scheduler_params (dict): parameters of scheduler_fn.
-         adaptive_params (dict): parameters of adaptive loss weight method. Now only support `{"method" : "uwl"}`.
+         adaptive_params (dict): parameters of the adaptive loss weighting method. Supports `{"method": "uwl"}`, `{"method": "metabalance"}` and `{"method": "gradnorm"}`.
          n_epoch (int): number of training epochs.
          earlystop_taskid (int): id of the task whose metric drives early stopping (default = 0).
          earlystop_patience (int): how many epochs to wait after the last validation AUC improvement (default = 10).
@@ -32,10 +36,8 @@ class MTLTrainer(object):
          model,
          task_types,
          optimizer_fn=torch.optim.Adam,
-         optimizer_params={
-             "lr": 1e-3,
-             "weight_decay": 1e-5
-         },
+         optimizer_params=None,
+         regularization_params=None,
          scheduler_fn=None,
          scheduler_params=None,
          adaptive_params=None,
@@ -43,10 +45,16 @@
          earlystop_taskid=0,
          earlystop_patience=10,
          device="cpu",
-         gpus=[],
+         gpus=None,
          model_path="./",
      ):
          self.model = model
+         if gpus is None:
+             gpus = []
+         if optimizer_params is None:
+             optimizer_params = {"lr": 1e-3, "weight_decay": 1e-5}
+         if regularization_params is None:
+             regularization_params = {"embedding_l1": 0.0, "embedding_l2": 0.0, "dense_l1": 0.0, "dense_l2": 0.0}
          self.task_types = task_types
          self.n_task = len(task_types)
          self.loss_weight = None
@@ -56,7 +64,27 @@
              self.adaptive_method = "uwl"
              self.loss_weight = nn.ParameterList(nn.Parameter(torch.zeros(1)) for _ in range(self.n_task))
              self.model.add_module("loss weight", self.loss_weight)
-         self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params) #default Adam optimizer
+         elif adaptive_params["method"] == "metabalance":
+             self.adaptive_method = "metabalance"
+             share_layers, task_layers = shared_task_layers(self.model)
+             self.meta_optimizer = MetaBalance(share_layers)
+             self.share_optimizer = optimizer_fn(share_layers, **optimizer_params)
+             self.task_optimizer = optimizer_fn(task_layers, **optimizer_params)
+         elif adaptive_params["method"] == "gradnorm":
+             self.adaptive_method = "gradnorm"
+             self.alpha = adaptive_params.get("alpha", 0.16)
+             share_layers = shared_task_layers(self.model)[0]
+             # GradNorm computes the gradient of each task loss w.r.t. the last
+             # fully connected shared layer weight (a 2-dimensional parameter)
+             for i in range(len(share_layers)):
+                 if share_layers[-i].ndim == 2:
+                     self.last_share_layer = share_layers[-i]
+                     break
+             self.initial_task_loss = None
+             self.loss_weight = nn.ParameterList(nn.Parameter(torch.ones(1)) for _ in range(self.n_task))
+             self.model.add_module("loss weight", self.loss_weight)
+         if self.adaptive_method != "metabalance":
+             self.optimizer = optimizer_fn(self.model.parameters(), **optimizer_params)  # default Adam optimizer
          self.scheduler = None
          if scheduler_fn is not None:
              self.scheduler = scheduler_fn(self.optimizer, **scheduler_params)
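For orientation, the three accepted `adaptive_params` settings wired up above could be selected as in this hedged sketch (`mmoe_model` and the task list are placeholders, not defined in this diff):

```python
# "uwl": learnable uncertainty weights, one optimizer over all parameters.
trainer = MTLTrainer(mmoe_model, task_types=["classification", "classification"],
                     adaptive_params={"method": "uwl"})

# "metabalance": MetaBalance meta-optimizer on the shared layers, with
# separate optimizers for shared and task-specific parameters.
trainer = MTLTrainer(mmoe_model, task_types=["classification", "classification"],
                     adaptive_params={"method": "metabalance"})

# "gradnorm": balances task gradients on the last 2-D shared parameter;
# alpha (default 0.16) controls the strength of the balancing.
trainer = MTLTrainer(mmoe_model, task_types=["classification", "classification"],
                     adaptive_params={"method": "gradnorm", "alpha": 0.16})
```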
@@ -65,27 +93,32 @@
          self.n_epoch = n_epoch
          self.earlystop_taskid = earlystop_taskid
          self.early_stopper = EarlyStopper(patience=earlystop_patience)
-         self.device = torch.device(device)

          self.gpus = gpus
          if len(gpus) > 1:
              print('parallel running on these gpus:', gpus)
              self.model = torch.nn.DataParallel(self.model, device_ids=gpus)
+         # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+         self.device = torch.device(device)
+         self.model.to(self.device)
          self.model_path = model_path
+         # Initialize regularization loss
+         self.reg_loss_fn = RegularizationLoss(**regularization_params)

      def train_one_epoch(self, data_loader):
          self.model.train()
          total_loss = np.zeros(self.n_task)
          tk0 = tqdm.tqdm(data_loader, desc="train", smoothing=0, mininterval=1.0)
          for iter_i, (x_dict, ys) in enumerate(tk0):
-             x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
+             x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
              ys = ys.to(self.device)
              y_preds = self.model(x_dict)
              loss_list = [self.loss_fns[i](y_preds[:, i], ys[:, i].float()) for i in range(self.n_task)]
              if isinstance(self.model, ESMM):
-                 loss = sum(loss_list[1:]) #ESSM only compute loss for ctr and ctcvr task
+                 # ESMM only computes loss for the CTR and CTCVR tasks
+                 loss = sum(loss_list[1:])
              else:
-                 if self.adaptive_method != None:
+                 if self.adaptive_method is not None:
                      if self.adaptive_method == "uwl":
                          loss = 0
                          for loss_i, w_i in zip(loss_list, self.loss_weight):
@@ -93,38 +126,74 @@
                              loss += 2 * loss_i * torch.exp(-w_i) + w_i
                  else:
                      loss = sum(loss_list) / self.n_task
-             self.model.zero_grad()
-             loss.backward()
-             self.optimizer.step()
+
+             # Add regularization loss
+             reg_loss = self.reg_loss_fn(self.model)
+             loss = loss + reg_loss
+             if self.adaptive_method == 'metabalance':
+                 self.share_optimizer.zero_grad()
+                 self.task_optimizer.zero_grad()
+                 self.meta_optimizer.step(loss_list)
+                 self.share_optimizer.step()
+                 self.task_optimizer.step()
+             elif self.adaptive_method == "gradnorm":
+                 self.optimizer.zero_grad()
+                 if self.initial_task_loss is None:
+                     self.initial_task_loss = [l.item() for l in loss_list]
+                 gradnorm(loss_list, self.loss_weight, self.last_share_layer, self.initial_task_loss, self.alpha)
+                 self.optimizer.step()
+                 # renormalize the loss weights so they sum to n_task
+                 loss_weight_sum = sum([w.item() for w in self.loss_weight])
+                 normalize_coeff = len(self.loss_weight) / loss_weight_sum
+                 for w in self.loss_weight:
+                     w.data = w.data * normalize_coeff
+             else:
+                 self.model.zero_grad()
+                 loss.backward()
+                 self.optimizer.step()
              total_loss += np.array([l.item() for l in loss_list])
          log_dict = {"task_%d:" % (i): total_loss[i] / (iter_i + 1) for i in range(self.n_task)}
+         loss_list = [total_loss[i] / (iter_i + 1) for i in range(self.n_task)]
          print("train loss: ", log_dict)
          if self.loss_weight:
              print("loss weight: ", [w.item() for w in self.loss_weight])

-     def fit(self, train_dataloader, val_dataloader):
-         self.model.to(self.device)
+         return loss_list
+
+     def fit(self, train_dataloader, val_dataloader, mode='base', seed=0):
+         total_log = []
+
          for epoch_i in range(self.n_epoch):
-             self.train_one_epoch(train_dataloader)
+             _log_per_epoch = self.train_one_epoch(train_dataloader)
+
              if self.scheduler is not None:
                  if epoch_i % self.scheduler.step_size == 0:
                      print("Current lr : {}".format(self.optimizer.state_dict()['param_groups'][0]['lr']))
-                 self.scheduler.step() #update lr in epoch level by scheduler
+                 self.scheduler.step()  # update lr at epoch level by scheduler
              scores = self.evaluate(self.model, val_dataloader)
              print('epoch:', epoch_i, 'validation scores: ', scores)
+
+             for score in scores:
+                 _log_per_epoch.append(score)
+
+             total_log.append(_log_per_epoch)
+
              if self.early_stopper.stop_training(scores[self.earlystop_taskid], self.model.state_dict()):
                  print('validation best auc of main task %d: %.6f' % (self.earlystop_taskid, self.early_stopper.best_auc))
                  self.model.load_state_dict(self.early_stopper.best_weights)
-                 torch.save(self.early_stopper.best_weights, os.path.join(self.model_path, "model.pth")) #save best auc model
                  break

+         torch.save(self.model.state_dict(), os.path.join(self.model_path, "model_{}_{}.pth".format(mode, seed)))  # save best auc model
+
+         return total_log
+
      def evaluate(self, model, data_loader):
          model.eval()
          targets, predicts = list(), list()
          with torch.no_grad():
              tk0 = tqdm.tqdm(data_loader, desc="validation", smoothing=0, mininterval=1.0)
              for i, (x_dict, ys) in enumerate(tk0):
-                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()} #tensor to GPU
+                 x_dict = {k: v.to(self.device) for k, v in x_dict.items()}  # tensor to GPU
                  ys = ys.to(self.device)
                  y_preds = self.model(x_dict)
                  targets.extend(ys.tolist())
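The `uwl` branch above is an uncertainty-based loss weighting in the style of Kendall et al. (2018): each task $i$ carries a learnable scalar $w_i$ (initialized to zero), interpretable as a log-variance, and the accumulation `loss += 2 * loss_i * torch.exp(-w_i) + w_i` corresponds to

$$
\mathcal{L} \;=\; \sum_{i=1}^{T} \left( 2\,\mathcal{L}_i\, e^{-w_i} + w_i \right),
$$

so a noisy task can raise its $w_i$ to shrink its effective weight $e^{-w_i}$, while the additive $w_i$ term keeps the weights from diverging. GradNorm, by contrast, keeps explicit positive weights (initialized to one) and, as the renormalization block above shows, rescales them after every optimizer step so that they sum to the number of tasks.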
@@ -142,4 +211,49 @@ class MTLTrainer(object):
                  x_dict = {k: v.to(self.device) for k, v in x_dict.items()}
                  y_preds = model(x_dict)
                  predicts.extend(y_preds.tolist())
-         return predicts
+         return predicts
+
+     def export_onnx(self, output_path, dummy_input=None, batch_size=2, seq_length=10, opset_version=14, dynamic_batch=True, device=None, verbose=False):
+         """Export the trained multi-task model to ONNX format.
+
+         This method exports multi-task learning models (e.g., MMOE, PLE, ESMM, SharedBottom)
+         to ONNX format for deployment. The exported model will have multiple outputs,
+         one per task.
+
+         Note:
+             The ONNX model will output a tensor of shape [batch_size, n_task], where
+             n_task is the number of tasks in the multi-task model.
+
+         Args:
+             output_path (str): Path to save the ONNX model file.
+             dummy_input (dict, optional): Example input dict {feature_name: tensor}.
+                 If not provided, dummy inputs will be generated automatically.
+             batch_size (int): Batch size for auto-generated dummy input (default: 2).
+             seq_length (int): Sequence length for SequenceFeature (default: 10).
+             opset_version (int): ONNX opset version (default: 14).
+             dynamic_batch (bool): Enable dynamic batch size (default: True).
+             device (str, optional): Device for export ('cpu', 'cuda', etc.).
+                 If None, defaults to 'cpu' for maximum compatibility.
+             verbose (bool): Print export details (default: False).
+
+         Returns:
+             bool: True if export succeeded, False otherwise.
+
+         Example:
+             >>> trainer = MTLTrainer(mmoe_model, task_types=["classification", "classification"], ...)
+             >>> trainer.fit(train_dl, val_dl)
+             >>> trainer.export_onnx("mmoe.onnx")
+
+             >>> # Export on a specific device
+             >>> trainer.export_onnx("mmoe.onnx", device="cpu")
+         """
+         from ..utils.onnx_export import ONNXExporter
+
+         # Handle DataParallel-wrapped model
+         model = self.model.module if hasattr(self.model, 'module') else self.model
+
+         # Use provided device or default to 'cpu'
+         export_device = device if device is not None else 'cpu'
+
+         exporter = ONNXExporter(model, device=export_device)
+         return exporter.export(output_path=output_path, dummy_input=dummy_input, batch_size=batch_size, seq_length=seq_length, opset_version=opset_version, dynamic_batch=dynamic_batch, verbose=verbose)
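End to end, the extended `fit` now returns a per-epoch log (per-task train losses followed by the per-task validation scores) and saves the checkpoint as `model_{mode}_{seed}.pth`. A hedged sketch of the full flow, with the trainer and dataloaders from the earlier sketches assumed:

```python
# Train; `mode` and `seed` only affect the checkpoint file name.
total_log = trainer.fit(train_dl, val_dl, mode="base", seed=0)
for epoch_row in total_log:
    print(epoch_row)  # [task_0 loss, ..., task_{n-1} loss, task_0 score, ...]

# Export for serving; the ONNX graph outputs a [batch_size, n_task] tensor.
ok = trainer.export_onnx("mmoe.onnx", device="cpu", verbose=True)
```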